Assignment 2 Music Generation¶

Task 1: Symbolic Unconditioned Generation¶

Train a RNN LSTM on 4 melodies (p(x)) and sample new sequences.¶

Installation & Imports¶

In [ ]:
# Install required libraries
# !pip install torch pretty_midi matplotlib midi2audio librosa
Requirement already satisfied: torch in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (2.7.0)
Requirement already satisfied: pretty_midi in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (0.2.10)
Collecting matplotlib
  Downloading matplotlib-3.10.3-cp313-cp313-macosx_11_0_arm64.whl.metadata (11 kB)
Collecting midi2audio
  Using cached midi2audio-0.1.1-py2.py3-none-any.whl.metadata (5.7 kB)
Collecting librosa
  Using cached librosa-0.11.0-py3-none-any.whl.metadata (8.7 kB)
Requirement already satisfied: filelock in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from torch) (3.18.0)
Requirement already satisfied: typing-extensions>=4.10.0 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from torch) (4.13.2)
Requirement already satisfied: setuptools in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from torch) (78.1.1)
Requirement already satisfied: sympy>=1.13.3 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from torch) (1.14.0)
Requirement already satisfied: networkx in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from torch) (3.5)
Requirement already satisfied: jinja2 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from torch) (3.1.6)
Requirement already satisfied: fsspec in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from torch) (2025.5.1)
Requirement already satisfied: numpy>=1.7.0 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from pretty_midi) (2.2.6)
Requirement already satisfied: mido>=1.1.16 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from pretty_midi) (1.3.3)
Requirement already satisfied: six in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from pretty_midi) (1.17.0)
Collecting contourpy>=1.0.1 (from matplotlib)
  Downloading contourpy-1.3.2-cp313-cp313-macosx_11_0_arm64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib)
  Using cached cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Downloading fonttools-4.58.1-cp313-cp313-macosx_10_13_universal2.whl.metadata (106 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib)
  Downloading kiwisolver-1.4.8-cp313-cp313-macosx_11_0_arm64.whl.metadata (6.2 kB)
Requirement already satisfied: packaging>=20.0 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from matplotlib) (25.0)
Collecting pillow>=8 (from matplotlib)
  Downloading pillow-11.2.1-cp313-cp313-macosx_11_0_arm64.whl.metadata (8.9 kB)
Collecting pyparsing>=2.3.1 (from matplotlib)
  Using cached pyparsing-3.2.3-py3-none-any.whl.metadata (5.0 kB)
Requirement already satisfied: python-dateutil>=2.7 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from matplotlib) (2.9.0.post0)
Collecting audioread>=2.1.9 (from librosa)
  Using cached audioread-3.0.1-py3-none-any.whl.metadata (8.4 kB)
Collecting numba>=0.51.0 (from librosa)
  Downloading numba-0.61.2-cp313-cp313-macosx_11_0_arm64.whl.metadata (2.7 kB)
Collecting scipy>=1.6.0 (from librosa)
  Downloading scipy-1.15.3-cp313-cp313-macosx_14_0_arm64.whl.metadata (61 kB)
Collecting scikit-learn>=1.1.0 (from librosa)
  Downloading scikit_learn-1.6.1-cp313-cp313-macosx_12_0_arm64.whl.metadata (31 kB)
Collecting joblib>=1.0 (from librosa)
  Downloading joblib-1.5.1-py3-none-any.whl.metadata (5.6 kB)
Requirement already satisfied: decorator>=4.3.0 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from librosa) (5.2.1)
Collecting soundfile>=0.12.1 (from librosa)
  Using cached soundfile-0.13.1-py2.py3-none-macosx_11_0_arm64.whl.metadata (16 kB)
Collecting pooch>=1.1 (from librosa)
  Using cached pooch-1.8.2-py3-none-any.whl.metadata (10 kB)
Collecting soxr>=0.3.2 (from librosa)
  Using cached soxr-0.5.0.post1-cp312-abi3-macosx_11_0_arm64.whl.metadata (5.6 kB)
Collecting lazy_loader>=0.1 (from librosa)
  Using cached lazy_loader-0.4-py3-none-any.whl.metadata (7.6 kB)
Collecting msgpack>=1.0 (from librosa)
  Downloading msgpack-1.1.0-cp313-cp313-macosx_11_0_arm64.whl.metadata (8.4 kB)
Collecting standard-aifc (from librosa)
  Downloading standard_aifc-3.13.0-py3-none-any.whl.metadata (969 bytes)
Collecting standard-sunau (from librosa)
  Downloading standard_sunau-3.13.0-py3-none-any.whl.metadata (914 bytes)
Collecting llvmlite<0.45,>=0.44.0dev0 (from numba>=0.51.0->librosa)
  Downloading llvmlite-0.44.0-cp313-cp313-macosx_11_0_arm64.whl.metadata (4.8 kB)
Requirement already satisfied: platformdirs>=2.5.0 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from pooch>=1.1->librosa) (4.3.8)
Collecting requests>=2.19.0 (from pooch>=1.1->librosa)
  Using cached requests-2.32.3-py3-none-any.whl.metadata (4.6 kB)
Collecting charset-normalizer<4,>=2 (from requests>=2.19.0->pooch>=1.1->librosa)
  Downloading charset_normalizer-3.4.2-cp313-cp313-macosx_10_13_universal2.whl.metadata (35 kB)
Collecting idna<4,>=2.5 (from requests>=2.19.0->pooch>=1.1->librosa)
  Using cached idna-3.10-py3-none-any.whl.metadata (10 kB)
Collecting urllib3<3,>=1.21.1 (from requests>=2.19.0->pooch>=1.1->librosa)
  Using cached urllib3-2.4.0-py3-none-any.whl.metadata (6.5 kB)
Collecting certifi>=2017.4.17 (from requests>=2.19.0->pooch>=1.1->librosa)
  Downloading certifi-2025.4.26-py3-none-any.whl.metadata (2.5 kB)
Collecting threadpoolctl>=3.1.0 (from scikit-learn>=1.1.0->librosa)
  Using cached threadpoolctl-3.6.0-py3-none-any.whl.metadata (13 kB)
Collecting cffi>=1.0 (from soundfile>=0.12.1->librosa)
  Downloading cffi-1.17.1-cp313-cp313-macosx_11_0_arm64.whl.metadata (1.5 kB)
Collecting pycparser (from cffi>=1.0->soundfile>=0.12.1->librosa)
  Using cached pycparser-2.22-py3-none-any.whl.metadata (943 bytes)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from sympy>=1.13.3->torch) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from jinja2->torch) (3.0.2)
Collecting standard-chunk (from standard-aifc->librosa)
  Downloading standard_chunk-3.13.0-py3-none-any.whl.metadata (860 bytes)
Collecting audioop-lts (from standard-aifc->librosa)
  Downloading audioop_lts-0.2.1-cp313-abi3-macosx_11_0_arm64.whl.metadata (1.6 kB)
Downloading matplotlib-3.10.3-cp313-cp313-macosx_11_0_arm64.whl (8.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 8.1/8.1 MB 2.6 MB/s eta 0:00:00a 0:00:01
Using cached midi2audio-0.1.1-py2.py3-none-any.whl (8.7 kB)
Using cached librosa-0.11.0-py3-none-any.whl (260 kB)
Using cached audioread-3.0.1-py3-none-any.whl (23 kB)
Downloading contourpy-1.3.2-cp313-cp313-macosx_11_0_arm64.whl (255 kB)
Using cached cycler-0.12.1-py3-none-any.whl (8.3 kB)
Downloading fonttools-4.58.1-cp313-cp313-macosx_10_13_universal2.whl (2.7 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.7/2.7 MB 4.4 MB/s eta 0:00:00a 0:00:01
Downloading joblib-1.5.1-py3-none-any.whl (307 kB)
Downloading kiwisolver-1.4.8-cp313-cp313-macosx_11_0_arm64.whl (65 kB)
Using cached lazy_loader-0.4-py3-none-any.whl (12 kB)
Downloading msgpack-1.1.0-cp313-cp313-macosx_11_0_arm64.whl (81 kB)
Downloading numba-0.61.2-cp313-cp313-macosx_11_0_arm64.whl (2.8 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.8/2.8 MB 6.3 MB/s eta 0:00:00a 0:00:01
Downloading llvmlite-0.44.0-cp313-cp313-macosx_11_0_arm64.whl (26.2 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 26.2/26.2 MB 5.4 MB/s eta 0:00:00a 0:00:01
Downloading pillow-11.2.1-cp313-cp313-macosx_11_0_arm64.whl (3.0 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.0/3.0 MB 6.1 MB/s eta 0:00:00a 0:00:01
Using cached pooch-1.8.2-py3-none-any.whl (64 kB)
Using cached pyparsing-3.2.3-py3-none-any.whl (111 kB)
Using cached requests-2.32.3-py3-none-any.whl (64 kB)
Downloading charset_normalizer-3.4.2-cp313-cp313-macosx_10_13_universal2.whl (199 kB)
Using cached idna-3.10-py3-none-any.whl (70 kB)
Using cached urllib3-2.4.0-py3-none-any.whl (128 kB)
Downloading certifi-2025.4.26-py3-none-any.whl (159 kB)
Downloading scikit_learn-1.6.1-cp313-cp313-macosx_12_0_arm64.whl (11.1 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 11.1/11.1 MB 4.7 MB/s eta 0:00:00a 0:00:01
Downloading scipy-1.15.3-cp313-cp313-macosx_14_0_arm64.whl (22.4 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 22.4/22.4 MB 6.5 MB/s eta 0:00:0000:0100:01
Using cached soundfile-0.13.1-py2.py3-none-macosx_11_0_arm64.whl (1.1 MB)
Downloading cffi-1.17.1-cp313-cp313-macosx_11_0_arm64.whl (178 kB)
Using cached soxr-0.5.0.post1-cp312-abi3-macosx_11_0_arm64.whl (156 kB)
Using cached threadpoolctl-3.6.0-py3-none-any.whl (18 kB)
Using cached pycparser-2.22-py3-none-any.whl (117 kB)
Downloading standard_aifc-3.13.0-py3-none-any.whl (10 kB)
Downloading audioop_lts-0.2.1-cp313-abi3-macosx_11_0_arm64.whl (26 kB)
Downloading standard_chunk-3.13.0-py3-none-any.whl (4.9 kB)
Downloading standard_sunau-3.13.0-py3-none-any.whl (7.4 kB)
Installing collected packages: standard-chunk, midi2audio, urllib3, threadpoolctl, soxr, scipy, pyparsing, pycparser, pillow, msgpack, llvmlite, lazy_loader, kiwisolver, joblib, idna, fonttools, cycler, contourpy, charset-normalizer, certifi, audioread, audioop-lts, standard-sunau, standard-aifc, scikit-learn, requests, numba, matplotlib, cffi, soundfile, pooch, librosa
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 32/32 [librosa]1/32 [librosa]ib]n]
Successfully installed audioop-lts-0.2.1 audioread-3.0.1 certifi-2025.4.26 cffi-1.17.1 charset-normalizer-3.4.2 contourpy-1.3.2 cycler-0.12.1 fonttools-4.58.1 idna-3.10 joblib-1.5.1 kiwisolver-1.4.8 lazy_loader-0.4 librosa-0.11.0 llvmlite-0.44.0 matplotlib-3.10.3 midi2audio-0.1.1 msgpack-1.1.0 numba-0.61.2 pillow-11.2.1 pooch-1.8.2 pycparser-2.22 pyparsing-3.2.3 requests-2.32.3 scikit-learn-1.6.1 scipy-1.15.3 soundfile-0.13.1 soxr-0.5.0.post1 standard-aifc-3.13.0 standard-chunk-3.13.0 standard-sunau-3.13.0 threadpoolctl-3.6.0 urllib3-2.4.0
In [8]:
# Imports
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import pretty_midi
import matplotlib.pyplot as plt
import pickle
import torch.nn.functional as F
import librosa
import collections
import math

from torch.utils.data import Dataset, DataLoader
from midi2audio import FluidSynth
from IPython.display import Audio, display
from matplotlib.ticker import MaxNLocator

Data Loading¶

In [3]:
# Load the pre‐serialized JSB Chorales dataset
with open("JSB-Chorales-dataset-master/jsb-chorales-quarter.pkl", "rb") as f:
    data = pickle.load(f, encoding="latin1")

# We’ll work with the training split here, you can also access 'valid' and 'test'
chorales = data["train"]
print(f"Loaded {len(chorales)} training chorales.")
print("Sample:", chorales[0][:5])
Loaded 229 training chorales.
Sample: [(60, 72, 79, 88), (72, 79, 88), (67, 70, 76, 84), (69, 77, 86), (67, 70, 79, 88)]

Dataset Context¶

The JSB Chorales dataset consists of 382 four-part harmonized chorales by J.S. Bach. It is widely used in symbolic music modeling and has been curated to support machine learning tasks. We use the version released by Zhuang et al., which is represented as a sequence of four‐voice chord events (soprano, alto, tenor, bass), quantized to quarter‐note durations.

Instead of modeling only the soprano line, we now build a polyphonic model that learns full four‐voice chorales in parallel. At each time step, the model will predict an entire 4‐tuple of MIDI pitches (or rests) for all voices simultaneously.

Preprocessing Steps¶

  1. Extract four‐voice chord tuples

    • For each chorale, read each 4‐element chord event (one MIDI pitch per voice).
    • Skip any chord where all four voices are rests (-1, -1, -1, -1).
    • Drop any chorale that has fewer than 10 valid chords.
  2. Build a chord vocabulary

    • Collect the set of all unique 4‐tuples (soprano, alto, tenor, bass) across the training split.
    • Map each unique chord‐tuple to a distinct integer index.
  3. Tokenize each chorale as a sequence of chord‐indices

    • Convert each 4‐tuple in a chorale to its index in the chord vocabulary.
    • Discard any chord not found in the vocabulary (e.g., if it only appeared in validation/test).
  4. Prepare sequence‐to‐sequence training pairs

    • Slide a fixed‐length window (e.g., 32 chords) over each tokenized chord sequence.
    • For each window, the input is the first 32 chord‐indices, and the target is the next 32 chord‐indices (shifted by one).
  5. Build ChordSequenceDataset and DataLoader

    • Wrap the tokenized sequences of indices in a PyTorch Dataset that returns (input_seq, target_seq) pairs.
    • Use a DataLoader with a suitable batch size (e.g., 64) to feed the LSTM.

After these steps, we feed full four‐voice chord sequences into our MusicRNN model so that at each step it learns to predict a 4‐voice chord rather than a single monophonic melody.

In [ ]:
# We build a sequence of 4‐tuples for all 4 harmonies: soprano, alto, tenor, bass.
# We skip any chord that is all rests (-1 in every voice), and drop very short chorales.
chord_seqs = []

for chorale in chorales:
    chord_list = []
    for chord in chorale:
        # Chord is either a list/tuple of length 4, or -1 for a complete rest
        if isinstance(chord, (list, tuple)) and len(chord) == 4:
            # Convert any numpy types to int and keep the 4‐tuple as is:
            chord_tuple = (int(chord[0]), int(chord[1]), int(chord[2]), int(chord[3]))
            # If the chord is NOT four rests, we keep it.  (If all four voices are -1, skip.)
            if chord_tuple != (-1, -1, -1, -1):
                chord_list.append(chord_tuple)
    # Only keep chorales longer than 10 chords
    if len(chord_list) > 10:
        chord_seqs.append(chord_list)

print(f"Extracted {len(chord_seqs)} four‐voice sequences.")
print("Example chord‐sequence (first 5 chords):", chord_seqs[0][:5])
Extracted 229 four‐voice sequences.
Example chord‐sequence (first 5 chords): [(60, 72, 79, 88), (67, 70, 76, 84), (67, 70, 79, 88), (65, 72, 81, 89), (65, 72, 81, 89)]

Vocabulary & Tokenization¶

In [ ]:
# Build a set of all unique 4‐tuples (chords) in the training split.

all_chords = sorted({tuple(chord) for seq in chord_seqs for chord in seq})
# Map each chord‐tuple to a unique integer index
chord_to_idx = {chord: i for i, chord in enumerate(all_chords)}
idx_to_chord = {i: chord for chord, i in chord_to_idx.items()}
vocab_size = len(chord_to_idx)

# Convert each chord‐tuple sequence into a list of indices
tokenized_chord_seqs = [[chord_to_idx[ch] for ch in seq] for seq in chord_seqs]

print("Four‐voice chord vocabulary size:", vocab_size)
print("Tokenized example (first 10 chord‐indices):", tokenized_chord_seqs[0][:10])
Four‐voice chord vocabulary size: 2113
Tokenized example (first 10 chord‐indices): [736, 1496, 1502, 1338, 1338, 537, 1697, 1634, 1704, 1445]

Dataset Class¶

In [ ]:
# Create Dataset class for LSTM training. 
# Takes tokenized melody sequences and splits into
# fixed-length input-output pairs.
class ChordSequenceDataset(Dataset):
    def __init__(self, token_chord_seqs, seq_len=32):
        super().__init__()
        self.samples = []
        # Slide a window of length seq_len over each chord‐token sequence
        for seq in token_chord_seqs:
            for i in range(len(seq) - seq_len):
                x = seq[i : i + seq_len]           # input: a sequence of chord‐indices
                y = seq[i + 1 : i + seq_len + 1]   # target: next‐chord at each step
                self.samples.append((x, y))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, y = self.samples[idx]
        # Return LongTensors of shape (seq_len,) of chord‐indices
        return torch.tensor(x, dtype=torch.long), \
               torch.tensor(y, dtype=torch.long)

DataLoader Preparation¶

In [ ]:
# Create batches of (input, target) pairs for training.
seq_len  = 32 # length of each input sequence (tries to predict 32 next notes)
batch_size = 64 # number of sequences per batch (process 64 input-output pairs at a time)

# Create dataset and dataloader
dataset = ChordSequenceDataset(tokenized_chord_seqs, seq_len=seq_len)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
print(f"Total training chord‐sequences: {len(dataset)}")
Total training chord‐sequences: 5186

Training Model¶

In [ ]:
class MusicRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, num_layers=2, seq_len=32):
        super(MusicRNN, self).__init__()
        # Embedding now maps each chord‐index to a dense vector
        self.embedding      = nn.Embedding(vocab_size, embedding_dim)
        # Positional embeddings add information about each timestep's position
        self.position_embed = nn.Embedding(seq_len, embedding_dim)

         # LSTM stack: processes the embedded sequence, with dropout between layers
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True, # input/output tensors have shape (batch, seq, feature)
            dropout=0.2 # dropout on outputs of all layers except the last
        )

        self.norm    = nn.LayerNorm(hidden_dim) # LayerNorm stabilizes the activations before the final layers
        self.dropout = nn.Dropout(0.3) # Dropout after LSTM to reduce overfitting
        self.fc      = nn.Linear(hidden_dim, vocab_size) # Final linear layer maps hidden states to vocabulary logits

    def forward(self, x):
        batch_size, seqlen = x.size()
        # Create a tensor of positions [0, 1, ..., seq_len-1] for each example
        positions = (torch.arange(seqlen, device=x.device)
                        .unsqueeze(0)
                        .expand(batch_size, seqlen))
        embeddings = self.embedding(x) + self.position_embed(positions)

        out, _     = self.lstm(embeddings)
        out        = self.norm(out)
        out        = self.dropout(out)
        
        logits     = self.fc(out)  # shape: (batch_size, seqlen, vocab_size)
        return logits
In [9]:
def train_rnn(model, dataloader, vocab_size, num_epochs=10, lr=0.001,
              device="cuda" if torch.cuda.is_available() else "cpu"):
    """
    Train the MusicRNN on the provided dataloader.

    model: instance of MusicRNN
    dataloader: yields (input_batch, target_batch) pairs
    vocab_size: size of the token vocabulary for loss calculation
    """
    model     = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn   = nn.CrossEntropyLoss()
    # Scheduler reduces LR by 0.5 if validation loss hasn't improved for 2 epochs
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                    optimizer, mode="min", factor=0.5, patience=2)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        for xb, yb in dataloader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()

            # Forward pass: get logits of shape (batch, seq_len, vocab_size)
            logits = model(xb)

            # Compute cross-entropy loss across all timesteps
            loss = loss_fn(
                logits.view(-1, vocab_size),   # (batch*seq_len, vocab_size)
                yb.view(-1)                     # (batch*seq_len,)
            )

            # Backward pass and gradient clipping
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Update model parameters
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs} | Loss: {avg_loss:.4f}")

        # Step the scheduler with the average training loss
        scheduler.step(avg_loss)
In [10]:
# Trains the Model for 10 epochs
model = MusicRNN(vocab_size=vocab_size, seq_len=32)
train_rnn(model, dataloader, vocab_size, num_epochs=10)
Epoch 1/10 | Loss: 5.1666
Epoch 2/10 | Loss: 2.5353
Epoch 3/10 | Loss: 1.3446
Epoch 4/10 | Loss: 0.8037
Epoch 5/10 | Loss: 0.5484
Epoch 6/10 | Loss: 0.4196
Epoch 7/10 | Loss: 0.3458
Epoch 8/10 | Loss: 0.2947
Epoch 9/10 | Loss: 0.2619
Epoch 10/10 | Loss: 0.2408

Sampling from the trained LSTM¶

In [ ]:
# 3 samples: (A) a random 4-note prefix, (B) a single-note "cold" start, or (C) a very short seed.

def sample_diverse(
    model,
    tokenized_seqs,
    max_length=64,
    prefix_type="random_short",  # "random_short", "single", or "fixed"
    fixed_prefix=None,           # only used if prefix_type=="fixed"
    prefix_len=4, 
    first_steps_temp=2.0,        # high temp for initial steps
    normal_temp=1.0,
    top_k=5,
    top_p=0.8,
    device="cuda" if torch.cuda.is_available() else "cpu"
):
    """
    prefix_type:
      - "fixed": uses fixed_prefix (list of IDs)
      - "random_short": picks a random melody and takes prefix_len tokens
      - "single": starts from 1 random token
    """
    model.eval().to(device)
    
    # Pick our seed
    if prefix_type == "fixed":
        assert fixed_prefix is not None
        prefix = fixed_prefix
    elif prefix_type == "random_short":
        seq = random.choice(tokenized_seqs)
        prefix = seq[:prefix_len]
    elif prefix_type == "single":
        prefix = [ random.choice(tokenized_seqs)[0] ]
    else:
        raise ValueError("bad prefix_type")

    generated = prefix[:]
    input_seq = torch.tensor([generated], device=device)
    
    def filter_logits(logits):
        from torch.nn.functional import softmax
        logits = logits.clone()
        # Top-k
        if top_k>0:
            kth = torch.topk(logits, top_k)[0][-1]
            logits[logits <  kth] = -1e9
        # Top-p
        if top_p>0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            cum = softmax(sorted_logits, dim=-1).cumsum(dim=-1)
            mask = cum > top_p
            mask[...,1:] = mask[...,:-1].clone()
            mask[...,0]  = False
            logits[ sorted_idx[mask] ] = -1e9
        return logits

    for i in range(max_length - len(prefix)):
        # Choose temperature
        temp = first_steps_temp if i < len(prefix) else normal_temp
        
        seq_len = model.position_embed.num_embeddings
        inp = input_seq[:, -seq_len:]
        logits = model(inp)[0, -1, :] / temp
        filt   = filter_logits(logits)
        probs  = F.softmax(filt, dim=-1)
        nxt    = torch.multinomial(probs, 1).item()

        generated.append(nxt)
        input_seq = torch.tensor([generated], device=device)

    return generated

# Try all three strategies:
gens = {}
gens["A_random4"] = sample_diverse(
    model,
    tokenized_chord_seqs,       
    prefix_type="random_short",
    prefix_len=4
)
gens["B_single"] = sample_diverse(
    model,
    tokenized_chord_seqs,
    prefix_type="single"
)
gens["C_fixed4"] = sample_diverse(
    model,
    tokenized_chord_seqs,
    prefix_type="fixed",
    fixed_prefix=tokenized_chord_seqs[0][:4]  # first 4 chords of the first chorale
)

# Now map each generated chord-index sequence back to actual 4-tuples
chord_sequences = {
    name: [idx_to_chord[idx] for idx in seq]
    for name, seq in gens.items()
}

# 3 different generated strategies
generated_chords   = chord_sequences["A_random4"]
generated_chords2  = chord_sequences["B_single"]
generated_chords3  = chord_sequences["C_fixed4"]

Save original & generated as MIDI and convert to WAV for listening¶

In [ ]:
# Helper function to write a list of MIDI pitches to a .mid file
# with all four voice‐notes in parallel at each time step.
def save_four_voice_midi(chord_seq, filename="polyphonic_output.mid", note_duration=0.5):
    pm = pretty_midi.PrettyMIDI()
    instr = pretty_midi.Instrument(program=0)  # single piano instrument
    current_time = 0.0

    for item in chord_seq:
        if isinstance(item, tuple):
            chord_tuple = item
        else:
            # assume 'item' is an index
            chord_tuple = idx_to_chord[item]

        for pitch in chord_tuple:
            if pitch != -1:
                note = pretty_midi.Note(
                    velocity=100,
                    pitch=pitch,
                    start=current_time,
                    end=current_time + note_duration
                )
                instr.notes.append(note)

        current_time += note_duration

    pm.instruments.append(instr)
    pm.write(filename)
In [ ]:
# Convert chord-indices → write a four-voice MIDI
save_four_voice_midi(generated_chords,  filename="generated_chords_A.mid")
save_four_voice_midi(generated_chords2, filename="generated_chords_B.mid")
save_four_voice_midi(generated_chords3, filename="generated_chords_C.mid")

# Convert original 4-voice (first 64 chords) and each generated version to WAV
save_four_voice_midi(tokenized_chord_seqs[0][:64], filename="original_chords.mid")
fs = FluidSynth("FluidR3_GM.sf2")
fs.midi_to_audio("original_chords.mid",      "original_chords.wav")
fs.midi_to_audio("generated_chords_A.mid",   "generated_A.wav")
fs.midi_to_audio("generated_chords_B.mid",   "generated_B.wav")
fs.midi_to_audio("generated_chords_C.mid",   "generated_C.wav")

# Play back and display audio in notebook
print("🎹 Original four-voice:")
display(Audio("original_chords.wav"))

print("🎹 Generated (A: random4):")
display(Audio("generated_A.wav"))

print("🎹 Generated (B: single-chord cold start):")
display(Audio("generated_B.wav"))

print("🎹 Generated (C: fixed4 prefix):")
display(Audio("generated_C.wav"))
FluidSynth runtime version 2.3.5
Copyright (C) 2000-2024 Peter Hanappe and others.
Distributed under the LGPL license.
SoundFont(R) is a registered trademark of Creative Technology Ltd.

Rendering audio to file 'original_chords.wav'..
FluidSynth runtime version 2.3.5
Copyright (C) 2000-2024 Peter Hanappe and others.
Distributed under the LGPL license.
SoundFont(R) is a registered trademark of Creative Technology Ltd.

Rendering audio to file 'generated_A.wav'..
FluidSynth runtime version 2.3.5
Copyright (C) 2000-2024 Peter Hanappe and others.
Distributed under the LGPL license.
SoundFont(R) is a registered trademark of Creative Technology Ltd.

Rendering audio to file 'generated_B.wav'..
FluidSynth runtime version 2.3.5
Copyright (C) 2000-2024 Peter Hanappe and others.
Distributed under the LGPL license.
SoundFont(R) is a registered trademark of Creative Technology Ltd.

Rendering audio to file 'generated_C.wav'..
🎹 Original four-voice:
Your browser does not support the audio element.
🎹 Generated (A: random4):
Your browser does not support the audio element.
🎹 Generated (B: single-chord cold start):
Your browser does not support the audio element.
🎹 Generated (C: fixed4 prefix):
Your browser does not support the audio element.

Extract pitches back from files for plotting¶

In [ ]:
def extract_midi_pitches(midi_file):
    """Load MIDI and return a list of (start_time, pitch) sorted by time."""
    midi_data = pretty_midi.PrettyMIDI(midi_file)
    pitches = []
    for note in midi_data.instruments[0].notes:
        pitches.append((note.start, note.pitch))
    # Sort by start time and return pitch only
    pitches.sort()
    
    return [p[1] for p in pitches]

def extract_pitch_from_wav(wav_file):
    """Use librosa’s pitch tracker to extract a MIDI‐rounded pitch contour."""
    y, sr = librosa.load(wav_file)
    pitches, magnitudes = librosa.piptrack(y=y, sr=sr)
    pitch_track = []
    for i in range(pitches.shape[1]):
        index = magnitudes[:, i].argmax()
        pitch = pitches[index, i]
        pitch_track.append(pitch if pitch > 0 else np.nan)
    # Convert Hz to MIDI pitch (round)
    midi_pitches = [int(round(librosa.hz_to_midi(p))) if not np.isnan(p) else np.nan for p in pitch_track]
    return midi_pitches

def plot_waveform(wav_file, ax, title="Waveform"):
    """Load a WAV file and plot its waveform on the given Axes."""
    y, sr = librosa.load(wav_file, sr=None)  # preserve native sample rate
    times = np.arange(len(y)) / sr
    ax.plot(times, y, linewidth=0.5)
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Amplitude")
    ax.grid(True, linestyle='--', alpha=0.5)

def plot_spectrogram(wav_file, ax, title="Spectrogram"):
    """Load a WAV file and plot its log-power spectrogram on the given Axes."""
    y, sr = librosa.load(wav_file, sr=None)
    # Compute short-time Fourier transform
    D = librosa.stft(y, n_fft=1024, hop_length=512)
    # Convert to decibels
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
    img = librosa.display.specshow(
        S_db,
        sr=sr,
        hop_length=512,
        x_axis='time',
        y_axis='hz',
        ax=ax
    )
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Frequency (Hz)")
    # Add a colorbar on the right of this axis
    plt.colorbar(img, ax=ax, format="%+2.0f dB")

Compare Original vs Generated A (random4): Waveform & Spectrogram¶

In [ ]:
fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharex=False)

# Original waveform (top-left)
plot_waveform("original.wav",     axes[0, 0], title="Original Waveform")
# Generated A waveform (top-right)
plot_waveform("generated_A.wav",   axes[0, 1], title="Generated A Waveform")

# Original spectrogram (bottom-left)
plot_spectrogram("original.wav",   axes[1, 0], title="Original Spectrogram")
# Generated A spectrogram (bottom-right)
plot_spectrogram("generated_A.wav", axes[1, 1], title="Generated A Spectrogram")

plt.tight_layout()
plt.show()
No description has been provided for this image

Compare Original vs Generated B (single‐note): Waveform & Spectrogram¶

In [ ]:
fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharex=False)

# Original waveform (top-left)
plot_waveform("original.wav",      axes[0, 0], title="Original Waveform")
# Generated B waveform (top-right)
plot_waveform("generated_B.wav",    axes[0, 1], title="Generated B Waveform")

# Original spectrogram (bottom-left)
plot_spectrogram("original.wav",    axes[1, 0], title="Original Spectrogram")
# Generated B spectrogram (bottom-right)
plot_spectrogram("generated_B.wav",  axes[1, 1], title="Generated B Spectrogram")

plt.tight_layout()
plt.show()
No description has been provided for this image

Compare Original vs Generated C (fixed4‐prefix): Waveform & Spectrogram¶

In [ ]:
fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharex=False)

# Original waveform (top-left)
plot_waveform("original.wav",      axes[0, 0], title="Original Waveform")
# Generated C waveform (top-right)
plot_waveform("generated_C.wav",    axes[0, 1], title="Generated C Waveform")

# Original spectrogram (bottom-left)
plot_spectrogram("original.wav",    axes[1, 0], title="Original Spectrogram")
# Generated C spectrogram (bottom-right)
plot_spectrogram("generated_C.wav",  axes[1, 1], title="Generated C Spectrogram")

plt.tight_layout()
plt.show()
No description has been provided for this image

Evaluation (Four‐Voice Chord Model)¶

1. Chord‐Level Cross‐Entropy Loss and Perplexity¶

We evaluate the four‐voice model on held‐out splits:

  1. Preprocess validation and test chorales into chord‐tuples.
  2. Tokenize each chord tuple with chord_to_idx.
  3. Build ChordSequenceDataset / DataLoader for each split.
  4. Use our evaluate() helper to compute chord‐level average cross‐entropy loss and perplexity.
  5. Print validation & test metrics.
In [ ]:
# 1. Preprocess valid & test splits into four‐voice chord sequences
def extract_chord_seqs(chorales, min_len=10):
    seqs = []
    for chorale in chorales:
        chord_list = []
        for chord in chorale:
            # Chord is a 4‐tuple or -1; keep only valid 4‐tuples
            if isinstance(chord, (list, tuple)) and len(chord) == 4:
                chord_tuple = (int(chord[0]), int(chord[1]), int(chord[2]), int(chord[3]))
                if chord_tuple != (-1, -1, -1, -1):
                    chord_list.append(chord_tuple)
        if len(chord_list) > min_len:
            seqs.append(chord_list)
    return seqs

valid_chord_seqs = extract_chord_seqs(data["valid"])
test_chord_seqs  = extract_chord_seqs(data["test"])
In [41]:
# 2. Tokenize using chord_to_idx (drops unseen chords)
valid_chord_tokens = [
    [chord_to_idx[ch] for ch in seq if ch in chord_to_idx]
    for seq in valid_chord_seqs
]
test_chord_tokens = [
    [chord_to_idx[ch] for ch in seq if ch in chord_to_idx]
    for seq in test_chord_seqs
]

# Build chord‐level datasets & dataloaders
val_dataset_ch  = ChordSequenceDataset(valid_chord_tokens, seq_len=seq_len)
test_dataset_ch = ChordSequenceDataset(test_chord_tokens,  seq_len=seq_len)
val_loader_ch   = DataLoader(val_dataset_ch,  batch_size=batch_size)
test_loader_ch  = DataLoader(test_dataset_ch, batch_size=batch_size)

print(f"Validation chord samples: {len(val_dataset_ch)},  Test chord samples: {len(test_dataset_ch)}")
Validation chord samples: 1258,  Test chord samples: 1396
In [ ]:
# 3. Evaluation helper
def evaluate(model, loader, vocab_size, device="cuda" if torch.cuda.is_available() else "cpu"):
    model.eval()
    loss_fn = nn.CrossEntropyLoss(reduction="sum")
    total_loss = 0.0
    total_tokens = 0
    
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)  # (batch, seq_len, vocab_size)
            
            loss = loss_fn(
                logits.view(-1, vocab_size),
                yb.view(-1)
            )
            total_loss += loss.item()
            total_tokens += yb.numel()
    
    avg_loss = total_loss / total_tokens
    ppl = np.exp(avg_loss)
    return avg_loss, ppl

# 4. Run chord‐level evaluation
val_loss_ch, val_ppl_ch   = evaluate(model, val_loader_ch,  vocab_size)
test_loss_ch, test_ppl_ch = evaluate(model, test_loader_ch, vocab_size)

print(f"Validation (chords)  —  Loss: {val_loss_ch:.4f},  Perplexity: {val_ppl_ch:.2f}")
print(f"Test (chords)        —  Loss: {test_loss_ch:.4f},  Perplexity: {test_ppl_ch:.2f}")
Validation (chords)  —  Loss: 7.3067,  Perplexity: 1490.21
Test (chords)        —  Loss: 7.2839,  Perplexity: 1456.60

Interpretation (Chord Level)¶

  • Validation loss = 7.3067 (perplexity ≈ 1490.21)
  • Test loss = 7.2839 (perplexity ≈ 1456.60)

A chord‐level perplexity of ~1457–1490 indicates the model is effectively choosing among ~1,450 equally likely four‐voice chord tokens at each step. The small gap between validation and test perplexities suggests the model generalizes reasonably well at the chord level, though absolute perplexity remains high (likely because the chord vocabulary is large).

2. Voice‐Specific Pitch Statistics¶

Next, we inspect each individual voice (soprano, alto, tenor, bass) to see whether the model captures their marginal pitch distributions. For each voice:

  1. Flatten all pitches of that voice in the test set.
  2. Flatten all pitches of that voice in one generated sample (we’ll use “Generated A (random4)”).
  3. Compute mean & standard deviation for both.
  4. Plot histograms side by side (test vs generated).
In [ ]:
# Flatten test set pitches for each of the 4 voices
# test_chord_tokens is a list of chord‐index sequences; convert back to chord_tuples
test_voice_pitches = {i: [] for i in range(4)}  # 0=soprano, 1=alto, 2=tenor, 3=bass

for seq in test_chord_tokens:
    for chord_idx in seq:
        chord_tuple = idx_to_chord[chord_idx]  # (soprano, alto, tenor, bass)
        for voice_idx in range(4):
            test_voice_pitches[voice_idx].append(chord_tuple[voice_idx])

# Flatten generated sample A (“random4”) pitches for each voice
gen_chord_ids_A = gens["A_random4"]  # list of chord‐indices
gen_voice_pitches_A = {i: [] for i in range(4)}

for chord_idx in gen_chord_ids_A:
    chord_tuple = idx_to_chord[chord_idx]
    for voice_idx in range(4):
        gen_voice_pitches_A[voice_idx].append(chord_tuple[voice_idx])

# Compute mean & std for each voice (test vs generated A)
for v in range(4):
    mean_test = np.mean(test_voice_pitches[v])
    std_test  = np.std(test_voice_pitches[v])
    mean_gen  = np.mean(gen_voice_pitches_A[v])
    std_gen   = np.std(gen_voice_pitches_A[v])
    voice_name = ["Soprano","Alto","Tenor","Bass"][v]
    print(f"{voice_name}:")
    print(f"  Test   mean = {mean_test:.2f}, std = {std_test:.2f}")
    print(f"  Gen A  mean = {mean_gen:.2f}, std = {std_gen:.2f}\n")
Soprano:
  Test   mean = 63.11, std = 5.36
  Gen A  mean = 66.30, std = 4.60

Alto:
  Test   mean = 71.99, std = 4.58
  Gen A  mean = 75.64, std = 3.62

Tenor:
  Test   mean = 77.66, std = 4.59
  Gen A  mean = 81.00, std = 3.28

Bass:
  Test   mean = 83.13, std = 4.99
  Gen A  mean = 85.67, std = 2.97

Interpretation (Voice Marginals)¶

  • Soprano: Test (63.11 ± 5.36), Gen A (66.30 ± 4.60). The generated soprano is shifted ~3 semitones higher on average and slightly less variable.
  • Alto: Test (71.99 ± 4.58), Gen A (75.64 ± 3.62). The generated alto is ~3.65 semitones higher.
  • Tenor: Test (77.66 ± 4.59), Gen A (81.00 ± 3.28). The generated tenor is ~3.34 semitones higher.
  • Bass: Test (83.13 ± 4.99), Gen A (85.67 ± 2.97). The generated bass is ~2.54 semitones higher.

All four voices in the generated sample are biased toward higher pitches compared to the test set, and their variances are slightly reduced. This indicates the model has learned pitch ranges but drifts upward in all voices.

In [ ]:
# Plot histograms (Test vs Generated A) for each voice
voice_labels = ["Soprano", "Alto", "Tenor", "Bass"]
fig, axes = plt.subplots(4, 2, figsize=(12, 12), sharey=False)

for v in range(4):
    # Test histogram (left column)
    axes[v, 0].hist(
        test_voice_pitches[v],
        bins=range(min(test_voice_pitches[v]), max(test_voice_pitches[v]) + 2),
        color='blue',
        alpha=0.7
    )
    axes[v, 0].set_title(f"{voice_labels[v]} Test Pitch Dist.")
    axes[v, 0].set_xlabel("MIDI Pitch")
    axes[v, 0].set_ylabel("Count")
    axes[v, 0].grid(True, linestyle='--', alpha=0.5)

    # Generated A histogram (right column)
    axes[v, 1].hist(
        gen_voice_pitches_A[v],
        bins=range(min(gen_voice_pitches_A[v]), max(gen_voice_pitches_A[v]) + 2),
        color='orange',
        alpha=0.7
    )
    axes[v, 1].set_title(f"{voice_labels[v]} Gen A Pitch Dist.")
    axes[v, 1].set_xlabel("MIDI Pitch")
    axes[v, 1].set_ylabel("Count")
    axes[v, 1].grid(True, linestyle='--', alpha=0.5)

plt.tight_layout()
plt.show()
No description has been provided for this image

3. Voice‐Transition (Bigram) Analysis¶

To see if the model learns realistic voice‐step transitions, we compare bigram frequencies (consecutive‐note pairs) in the test set vs. the generated sample for each voice. We will:

  1. Build bigram counts for each voice in test set: count all pairs (pitch_t, pitch_{t+1}).
  2. Build bigram counts for each voice in one generated sample (A).
  3. Convert these counts to conditional distributions P(next_pitch | current_pitch) and measure KL divergence from test to generated, per voice.
In [ ]:
def compute_bigram_probs(pitch_sequence):
    """Given a list of pitches, returns dict: { current_pitch: { next_pitch: prob } }."""
    counts = {}
    for (p1, p2) in zip(pitch_sequence[:-1], pitch_sequence[1:]):
        counts.setdefault(p1, collections.Counter())
        counts[p1][p2] += 1
    # Normalize to probabilities
    bigram_probs = {}
    for p1, counter in counts.items():
        total = sum(counter.values())
        bigram_probs[p1] = {p2: count / total for p2, count in counter.items()}
    return bigram_probs

def kl_divergence(p_dist, q_dist):
    """
    Compute KL(P || Q) where P and Q are dicts of { symbol: prob }.
    Missing symbols in Q receive a small epsilon probability.
    """
    epsilon = 1e-8
    kl = 0.0
    for symbol, p_val in p_dist.items():
        q_val = q_dist.get(symbol, epsilon)
        kl += p_val * math.log(p_val / q_val)
    return kl

# Build bigram distributions for each voice in test set
test_bigram = {}
for v in range(4):
    test_bigram[v] = compute_bigram_probs(test_voice_pitches[v])

# Build bigram distributions for each voice in generated A
gen_bigram_A = {}
for v in range(4):
    gen_bigram_A[v] = compute_bigram_probs(gen_voice_pitches_A[v])

# Compute average KL divergence across all current_pitch contexts, per voice
kl_results = {}
for v in range(4):
    kl_sum, count = 0.0, 0
    for p1, p_dist in test_bigram[v].items():
        q_dist = gen_bigram_A[v].get(p1, {})
        kl_sum += kl_divergence(p_dist, q_dist)
        count += 1
    kl_results[v] = kl_sum / count if count > 0 else float('nan')

# Print KL divergence for each voice
for v in range(4):
    print(f"{voice_labels[v]} KL divergence (Test || Gen A) = {kl_results[v]:.4f}")
Soprano KL divergence (Test || Gen A) = 13.9683
Alto KL divergence (Test || Gen A) = 12.3597
Tenor KL divergence (Test || Gen A) = 13.4401
Bass KL divergence (Test || Gen A) = 13.4021

Interpretation (Voice Bigram KL)¶

  • Each voice’s KL divergence is around 12–14. This large value means the generated voice‐leading transitions differ substantially from the test distribution. In other words, the model’s step‐to‐step pitch choices for each voice don’t closely match the test chorales.

4. Chord‐Transition (Bigram) Analysis¶

Finally, we check whether the model’s predicted chord transitions (4‐voice bigrams) match those in the test set:

  1. Build test‐set bigram counts on chord‐indices.
  2. Build generated sample A bigram counts on chord‐indices.
  3. Compute KL divergence of chord‐transition distributions.
In [ ]:
# Build bigram distributions on chord‐indices (test set)
test_chords_flat = [idx for seq in test_chord_tokens for idx in seq]
test_chord_bigram = compute_bigram_probs(test_chords_flat)

# Build bigram distributions on chord‐indices (generated A)
gen_chord_bigram_A = compute_bigram_probs(gen_chord_ids_A)

# Compute KL divergence for each chord context
kl_sum, count = 0.0, 0
for c1, p_dist in test_chord_bigram.items():
    q_dist = gen_chord_bigram_A.get(c1, {})
    kl_sum += kl_divergence(p_dist, q_dist)
    count += 1
kl_chord = kl_sum / count if count > 0 else float('nan')

print(f"Chord‐level KL divergence (Test || Gen A) = {kl_chord:.4f}")
Chord‐level KL divergence (Test || Gen A) = 17.6501

Interpretation (Chord Bigram KL Divergence)¶

  • A chord‐level KL of ~17.65 is very large, indicating the generated chord‐to‐chord transitions deviate greatly from the test distribution. The model is not capturing four‐voice harmonic progressions as faithfully as expected.

Overall Evaluation Summary¶

  1. Chord‐Level Perplexity:

    • Validation ≈ 1490.21, Test ≈ 1456.60

    This high perplexity reflects the large chord vocabulary (~1,500 distinct chords) and shows the model still struggles to narrow down its predictions reliably.

  2. Voice Marginals:

    • Generated voices (Gen A) skew higher in pitch (≈ +2–4 semitones) relative to the test distributions, with slightly lower variance.
  3. Voice Bigram KL Divergences (Test‖Gen A):

    • Soprano ≈ 13.97
    • Alto ≈ 12.36
    • Tenor ≈ 13.44
    • Bass ≈ 13.40

    These large KL values indicate the model’s stepwise pitch transitions do not match the test set well for any voice.

  4. Chord Bigram KL Divergence:

    • ≈ 17.65

    The four‐voice harmonic transitions in generated chords differ significantly from those in the held‐out chorales.

Taken together, the evaluation shows that, although the model has learned to generate plausible chord shapes (it still produces four‐voice chords), its distributions—both marginal and sequential—drift noticeably from the test data. There is ample room for improvement (e.g., more training data, larger/deeper architectures, attention to balancing pitch ranges).

Task 2 (NEW)¶

Overview¶

We extended the functionality of the Piano Genie project by Google to generate music based on words typed with the full QWERTY keyboard. This allows us to also feed in and output velocity into the model, where the project did not do so before.

The original Github for the project can be found here: https://github.com/chrisdonahue/piano-genie

Inspiration¶

Monkeytype is a popular typing speed website that some of our members frequent. We were curious to see if we could generate music dynamically based on what the user was typing https://monkeytype.com/

We saw the Piano Genie demo in class, and saw an opportunity to extend it to explore this idea. The original demo can be found here: https://www.i-am.ai/piano-genie.html

Usage¶

Use our utility HTML script to open a window where you can type out a paragraph. The HTML script will encode your typing as a combination of keys clicked, timestamp, and your word per minute.

The utility script will output these stats to a .csv file that can be passed into the model to generate new music.

QWERTY Input¶

We experimented with a few different ideas for how the QWERTY keyboard could augment what music is played.

Velocity per row¶

We tested assigning each row of a keyboard to different velocities to be encoded during typing. For instance, the row with the keys "QWERTYUIOP" would be assigned a higher velocity value than the row with the keys "ZXCVBNM".

We kept 8 note bins (forcing keys past the 8th in the row counting left to right to represent the 8th bin) during our initial experiments.

Notes¶

We left some of the original documentation in to assist with setting context and if the project reviewers are interested in what work was done before. All original documentation in pure markdown blocks have been surrounded by codeblocks, which can be removed for your viewing convenience. Additionally, inline markdown is most likely partially or fully from the original repository. We mark documentation written fully by us with the tag (NEW). Also credit to copilot for assistance with the project.

Performance Encoding (NEW)¶

We created a new system in order to encode our user's musical performance on a QWERTY keyboard into a format interpretable by a machine learning model.

User UI¶

No description has been provided for this image

The new UI for adding text is a simple text input box, with the ability to click in and type text. The user can then download their performance as a CSV file using the Download CSV button.

CSV File Format¶

The first row of the CSV file is a header, containing the following column names:

  • Key: This column stores the key pressed by the user on the QWERTY keyboard.
  • Seconds: This column stores the time elapsed in seconds since the very first key press. The time is recorded with three decimal places.
  • WPM: This column stores the words per minute calculated at the moment the key was pressed. The WPM is recorded with one decimal place.

Each subsequent row after the first in the CSV file represents a single key press event. The values in these rows correspond to the columns defined in the header: the specific key pressed, the time of the key press, and the calculated words per minute at that moment.

For example, the row I,3.496,10.3 indicates that the key "I" was pressed at 3.496 seconds after the initial key press, and at that time, the user's typing speed was 10.3 words per minute.

Example CSV of Short Performance¶

In the below performance, the phrase I am going to pass was typed into the terminal.

Key,Seconds,WPM
Shift,0.000,Infinity
Shift,3.373,7.1
I,3.496,10.3
 ,3.622,13.3
w,3.697,16.2
Backspace,3.957,18.2
a,4.023,20.9
m,4.121,23.3
 ,4.229,25.5
g,4.298,27.9
o,4.399,30.0
i,4.538,31.7
n,4.591,34.0
g,4.665,36.0
 ,4.731,38.0
t,4.798,40.0
o,4.837,42.2
 ,4.923,43.9
p,5.040,45.2
a,5.102,47.0
s,5.246,48.0
s,5.367,49.2

Full Script For User Input¶

<!DOCTYPE html>
<html>
  <body>
    <textarea id="t" rows="10" cols="60" placeholder="Start typing…"></textarea>
    <br />
    <button id="download">Download CSV</button>

    <script>
      let startTime = null;
      let charCount = 0;
      // Header: Key, Seconds since first press, WPM at that moment
      const rows = [["Key", "Seconds", "WPM"]];

      const log = (key, secs, wpm) => {
        rows.push([key, secs.toFixed(3), wpm.toFixed(1)]);
      };

      document.getElementById("t").addEventListener("keydown", (e) => {
        if (startTime === null) {
          startTime = performance.now();
        }
        charCount++;
        const now = performance.now();
        const elapsedMs = now - startTime;
        const elapsedSecs = elapsedMs / 1000;
        const elapsedMins = elapsedMs / 60000;
        // WPM = (chars ÷ 5) ÷ elapsedMinutes
        const wpm = charCount / 5 / elapsedMins;
        log(e.key, elapsedSecs, wpm);
      });

      document.getElementById("download").addEventListener("click", () => {
        const csvContent = rows.map((r) => r.join(",")).join("\n");
        const blob = new Blob([csvContent], { type: "text/csv" });
        const url = URL.createObjectURL(blob);
        const a = document.createElement("a");
        a.href = url;
        a.download = "typing_wpm_timestamps.csv";
        document.body.appendChild(a);
        a.click();
        document.body.removeChild(a);
        URL.revokeObjectURL(url);
      });
    </script>
  </body>
</html>

Initial tests:¶

During our initial tests on whether or not we could get the pipeline to run, we converted our performance into 8 bins using the following script:

In [ ]:
# version 1
def letter_to_button_keyboard(letter):
    # Map letters on the keyboard to button indices, top row, middle row, bottom row
    top = "qwertyuiop"
    middle = "asdfghjkl"
    bottom = "zxcvbnm"
    if letter in top:
        return min(top.index(letter), 8), 40
    elif letter in middle:
        return min(middle.index(letter), 8), 80
    elif letter in bottom:
        return min(bottom.index(letter), 8), 120
    else:
        return 0, 0

After working through the missing dependencies and various setup quirks, we were able to generate this with our input script and new music:

In [ ]:
# Convert chord-indices → write a four-voice MIDI


# Convert original 4-voice (first 64 chords) and each generated version to WAV
fs = FluidSynth("FluidR3_GM.sf2")
fs.midi_to_audio("output_1748243578.2774305.mid", "output_1748243578.2774305.wav")
fs.midi_to_audio("output_1748243604.4184122.mid",   "output_1748243604.4184122.wav")
fs.midi_to_audio("output_1748243854.1907494.mid",   "output_1748243854.1907494.wav")
fs.midi_to_audio("output_1748243963.1448758.mid",   "output_1748243963.1448758.wav")

# Play back and display audio in notebook
# print("🎹 Attempt 1:")
# display(Audio("output_1748243578.wav"))

# print("🎹 Attempt 2:")
# display(Audio("output_1748243604.wav"))

# print("🎹 Attempt 3:")
# display(Audio("output_1748243854.wav"))

# print("🎹 Attempt 4:")
# display(Audio("output_1748243963.wav"))

Pipeline (NEW):¶

Below are our modifications to the original pipeline. The pipeline has been modified in order to use the following inputs and outputs:

  • Input: our new CSV file format representing what keys users have clicked while typing
  • Output: A MIDI file that converts general music "bins" (where the bins are the keys the users clicked) to proper notes. In other words, converts our user's keyboard input into a musical piece.

We have modified the pipeline to accept 26 bins, and also output velocity.

## Primer on Piano Genie

The generative model we will train is called [Piano Genie](https://magenta.tensorflow.org/pianogenie) (Donahue et al. 2019). Piano Genie is a system which maps amateur improvisations on a miniature 8-button keyboard ([video](https://www.youtube.com/watch?v=YRb0XAnUpIk), [demo](https://piano-genie.glitch.me)) into realistic performances on a full 88-key piano.

To achieve this, Piano Genie adopts an _autoencoder_ approach. First, an _encoder_ maps professional piano performances into this 8-button space. Then, a _decoder_ attempts to reconstruct the original piano performance from the 8-button version. The entire system is trained end-to-end to minimize the decoder's reconstruction error. At performance time, we replace the encoder with a user improvising on an 8-button controller, and use the pre-trained decoder to generate a corresponding piano performance.

<center><img src="https://raw.githubusercontent.com/chrisdonahue/music-cocreation-tutorial/main/part-1-py-training/figures/overview.png" width=600px/></center>

At a low-level, both the encoder and the decoder for Piano Genie are lightweight recurrent neural networks, which are suitable for real-time performance even on mobile CPUs. The discrete bottleneck is achieved using a technique called _integer-quantized autoencoding_ (IQAE), which was also proposed in the Piano Genie paper.
In [16]:
#@title **(Step 1)** Parse MIDI piano performances into simple lists of notes

# @markdown *Note*: Check this box to rebuild the dataset from scratch.
REBUILD_DATASET = False  # @param{type:"boolean"}

# @markdown To train Piano Genie, we will use a dataset of professional piano performances called [MAESTRO](https://magenta.tensorflow.org/datasets/maestro) (Hawthorne et al. 2019).
# @markdown Each performance in this dataset was captured by a Disklavier, a computerized piano which can record human performances in MIDI format, i.e., as timestamped sequences of notes.

PIANO_LOWEST_KEY_MIDI_PITCH = 21
PIANO_NUM_KEYS = 88

import gzip
import json
from collections import defaultdict

from tqdm.notebook import tqdm


def download_and_parse_maestro():
    # Install pretty_midi
    !pip install pretty_midi
    import pretty_midi

    # Download MAESTRO dataset (Hawthorne+ 2018)
    # !wget -nc https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip
    !unzip maestro-v2.0.0-midi.zip

    # Parse MAESTRO dataset
    dataset = defaultdict(list)
    with open("maestro-v2.0.0/maestro-v2.0.0.json", "r") as f:
        for attrs in tqdm(json.load(f)):
            split = attrs["split"]
            midi = pretty_midi.PrettyMIDI("maestro-v2.0.0/" + attrs["midi_filename"])
            assert len(midi.instruments) == 1
            # @markdown Formally, a piano performance is a sequence of notes: $\mathbf{x} = (x_1, \ldots, x_N)$, where each $x_i = (t_i, d_i, k_i, v_i)$, signifying:
            notes = [
                (
                    # @markdown 1. (When the key was pressed) An _onset_ time $t_i \in \mathbb{T}$, where $\mathbb{T} = \{ t \in \mathbb{R} \mid 0 \leq t \leq T \}$
                    float(n.start),
                    # @markdown 2. (How long the key was held) A _duration_ $d_i \in \mathbb{R}_{>0}$
                    float(n.end) - float(n.start),
                    # @markdown 3. (Which key was pressed) A _key_ index $k_i \in \mathbb{K}$, where $\mathbb{K} = \{\text{A0}, \ldots, \text{C8}\}$ and $|\mathbb{K}| = 88$
                    int(n.pitch - PIANO_LOWEST_KEY_MIDI_PITCH),
                    # @markdown 4. (How hard the key was pressed) A _velocity_ $v_i \in \mathbb{V}$, where $\mathbb{V} = \{1, \ldots, 127\}$
                    int(n.velocity),
                )
                for n in midi.instruments[0].notes
            ]

            # This list is in sorted order of onset time, i.e., $t_{i-1} \leq t_i ~\forall~i \in \{2, \ldots, N\}$.
            notes = sorted(notes, key=lambda n: (n[0], n[2]))
            assert all(
                [
                    all(
                        [
                            # Start times should be non-negative
                            n[0] >= 0,
                            # Note durations should be strictly positive, i.e., $d_i > 0$
                            n[1] > 0,
                            # Key index should be in range of the piano
                            0 <= n[2] and n[2] < PIANO_NUM_KEYS,
                            # Velocity should be valid
                            1 <= n[3] and n[3] < 128,
                        ]
                    )
                    for n in notes
                ]
            )
            dataset[split].append(notes)

        return dataset


if REBUILD_DATASET:
    DATASET = download_and_parse_maestro()
    with gzip.open("maestro-v2.0.0-simple.json.gz", "w") as f:
        f.write(json.dumps(DATASET).encode("utf-8"))
else:
    # !wget -nc https://github.com/chrisdonahue/music-cocreation-tutorial/raw/main/part-1-py-training/data/maestro-v2.0.0-simple.json.gz
    with gzip.open("maestro-v2.0.0-simple.json.gz", "rb") as f:
        DATASET = json.load(f)

print([(s, len(DATASET[s])) for s in ["train", "validation", "test"]])
[('train', 967), ('validation', 137), ('test', 178)]
In [17]:
# Step 1: Inspect the data structure and content
print("Dataset splits and number of performances:")
for split, performances in DATASET.items():
    print(f"- {split}: {len(performances)} performances")

# Inspect the first performance in the training set
if DATASET['train']:
    first_performance = DATASET['train'][0]
    print("\nStructure of the first performance (first 5 notes):")
    print(first_performance[:5])
    print("\nData types of the first note:")
    if first_performance:
        first_note = first_performance[0]
        for i, value in enumerate(first_note):
            print(f"- Element {i}: {type(value)}")
Dataset splits and number of performances:
- train: 967 performances
- validation: 137 performances
- test: 178 performances

Structure of the first performance (first 5 notes):
[[0.9830729166666666, 0.8268229166666666, 46, 52], [1.7838541666666665, 0.12239583333333348, 51, 67], [2.1471354166666665, 1.453125, 57, 65], [2.153645833333333, 2.170572916666667, 50, 45], [2.1783854166666665, 1.0651041666666665, 40, 39]]

Data types of the first note:
- Element 0: <class 'float'>
- Element 1: <class 'float'>
- Element 2: <class 'int'>
- Element 3: <class 'int'>
In [18]:
# Install necessary libraries
%pip install pandas matplotlib seaborn numpy
Collecting pandas
  Downloading pandas-2.2.3-cp313-cp313-macosx_11_0_arm64.whl.metadata (89 kB)
Requirement already satisfied: matplotlib in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (3.10.3)
Collecting seaborn
  Using cached seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Requirement already satisfied: numpy in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (2.2.6)
Requirement already satisfied: python-dateutil>=2.8.2 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from pandas) (2.9.0.post0)
Collecting pytz>=2020.1 (from pandas)
  Using cached pytz-2025.2-py2.py3-none-any.whl.metadata (22 kB)
Collecting tzdata>=2022.7 (from pandas)
  Using cached tzdata-2025.2-py2.py3-none-any.whl.metadata (1.4 kB)
Requirement already satisfied: contourpy>=1.0.1 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from matplotlib) (1.3.2)
Requirement already satisfied: cycler>=0.10 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from matplotlib) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from matplotlib) (4.58.1)
Requirement already satisfied: kiwisolver>=1.3.1 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from matplotlib) (1.4.8)
Requirement already satisfied: packaging>=20.0 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from matplotlib) (25.0)
Requirement already satisfied: pillow>=8 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from matplotlib) (11.2.1)
Requirement already satisfied: pyparsing>=2.3.1 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from matplotlib) (3.2.3)
Requirement already satisfied: six>=1.5 in /Users/andysmithwick/anaconda3/envs/cse253/lib/python3.13/site-packages (from python-dateutil>=2.8.2->pandas) (1.17.0)
Downloading pandas-2.2.3-cp313-cp313-macosx_11_0_arm64.whl (11.3 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 11.3/11.3 MB 15.2 MB/s eta 0:00:00 0:00:01
Using cached seaborn-0.13.2-py3-none-any.whl (294 kB)
Using cached pytz-2025.2-py2.py3-none-any.whl (509 kB)
Using cached tzdata-2025.2-py2.py3-none-any.whl (347 kB)
Installing collected packages: pytz, tzdata, pandas, seaborn
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4/4 [seaborn]m3/4 [seaborn]
Successfully installed pandas-2.2.3 pytz-2025.2 seaborn-0.13.2 tzdata-2025.2
Note: you may need to restart the kernel to use updated packages.
In [19]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Step 2: Analyze note properties

# Combine all notes from all splits into a single list for analysis
all_notes = []
for split, performances in DATASET.items():
    for performance in performances:
        all_notes.extend(performance)

# Create a pandas DataFrame for easier analysis
notes_df = pd.DataFrame(all_notes, columns=['onset_time', 'duration', 'key_index', 'velocity'])

# Distribution of key indices
plt.figure(figsize=(12, 6))
sns.histplot(notes_df['key_index'], bins=PIANO_NUM_KEYS, kde=False)
plt.title('Distribution of Piano Key Indices')
plt.xlabel('Key Index (0-87)')
plt.ylabel('Frequency')
plt.show()

# Distribution of velocities
plt.figure(figsize=(12, 6))
sns.histplot(notes_df['velocity'], bins=127, kde=False)
plt.title('Distribution of Velocities')
plt.xlabel('Velocity (1-127)')
plt.ylabel('Frequency')
plt.show()

# Distribution of durations
plt.figure(figsize=(12, 6))
sns.histplot(notes_df['duration'], bins=50, kde=True)
plt.title('Distribution of Note Durations')
plt.xlabel('Duration (seconds)')
plt.ylabel('Frequency')
plt.xlim(0, 5) # Limit x-axis for better visualization of common durations
plt.show()

# Distribution of inter-onset intervals
# Calculate inter-onset intervals for each performance
all_iois = []
for split, performances in DATASET.items():
    for performance in performances:
        onsets = [n[0] for n in performance]
        iois = np.diff(onsets)
        all_iois.extend(iois)

plt.figure(figsize=(12, 6))
sns.histplot(all_iois, bins=50, kde=True)
plt.title('Distribution of Inter-Onset Intervals')
plt.xlabel('Inter-Onset Interval (seconds)')
plt.ylabel('Frequency')
plt.xlim(0, 2) # Limit x-axis for better visualization of common intervals
plt.show()
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [20]:
# Step 3: Analyze performance properties

# Number of notes per performance
num_notes_per_performance = [len(p) for split, performances in DATASET.items() for p in performances]
plt.figure(figsize=(12, 6))
sns.histplot(num_notes_per_performance, bins=50, kde=True)
plt.title('Distribution of Number of Notes per Performance')
plt.xlabel('Number of Notes')
plt.ylabel('Frequency')
plt.show()

# Total duration of performances
total_duration_per_performance = [p[-1][0] - p[0][0] if p else 0 for split, performances in DATASET.items() for p in performances]
plt.figure(figsize=(12, 6))
sns.histplot(total_duration_per_performance, bins=50, kde=True)
plt.title('Distribution of Total Performance Duration')
plt.xlabel('Duration (seconds)')
plt.ylabel('Frequency')
plt.show()
No description has been provided for this image
No description has been provided for this image

Reintroducing Velocity to the autoencoder (NEW)¶

As noted in the original documentation, the original version of Piano Genie did not utilize velocity to make its predictions. Since we use the entire QWERTY keyboard for our project, we saw an opportunity to test if the model performance would differ by adding velocity to the autoencoder.

In the PianoDecoder class, we needed to reintroduce the velocity parameter initially dropped from the inputs, as shown below:

inputs = [
            F.one_hot(k_m1, PIANO_NUM_KEYS + 1),
            t.unsqueeze(dim=2),
            b.unsqueeze(dim=2),
            v.unsqueeze(dim=2),
        ]

We also had to do a similar thing for the encoder, changing function signatures:

def forward(self, k, t, v):
        inputs = [
            F.one_hot(k, PIANO_NUM_KEYS),
            t.unsqueeze(dim=2),
            v.unsqueeze(dim=2),
        ]

Location in original docs mentioning lack of velocity:¶

we anticipate that it will be frustrating for users if the model predicts dynamics on their behalf, so we remove velocity terms  𝐯 :
In [21]:
# @title **(Step 2)** Define Piano Genie autoencoder

# @markdown Our intended interaction for Piano Genie is to have users perform on a miniature 8-button keyboard and automatically map each of their button presses to a key on a piano.
# @markdown Similarly to our formalization of piano performances, we will formalize a "button performance" as a sequence of "notes", where piano keys $k_i$ are replaced with buttons $b_i$, and we remove velocity since our button controller is not velocity sensitive.
# @markdown So a button performance $\mathbf{c}$ is:

# @markdown - $\mathbf{c} = (c_1, \ldots, c_N)$, where $c_i = (t_i, d_i, b_i \in \mathbb{B})$, i.e., (onsets, durations, buttons), and $\mathbb{B} = \{ \color{#EE2B29}\blacksquare, \color{#ff9800}\blacksquare, \color{#ffff00}\blacksquare, \color{#c6ff00}\blacksquare, \color{#00e5ff}\blacksquare, \color{#2979ff}\blacksquare, \color{#651fff}\blacksquare, \color{#d500f9}\blacksquare \}$

# @markdown And a corresponding piano performance is:

# @markdown - $\mathbf{x} = (x_1, \ldots, x_N)$, where $x_i = (t_i, d_i, k_i, v_i)$, i.e., (onsets, durations, keys, velocities)

# @markdown To map button performances into piano performances, we will train a generative model $P(\mathbf{x} \mid \mathbf{c})$.
# @markdown In practice, we will factorize this joint distribution over note sequences $\mathbf{x}$ into the product of conditional probabilities of individual notes: $P(\mathbf{x} \mid \mathbf{c}) = \prod_{i=1}^{N} P(x_i \mid \mathbf{x}_{< i}, \mathbf{c})$.

# @markdown Hence, our **overall goal is to learn** $P(x_i \mid \mathbf{x}_{< i}, \mathbf{c})$,
# @markdown which we will **approximate by modeling**:

# @markdown <center>$P(k_i \mid \mathbf{k}_{<i}, \mathbf{t}_{\leq i}, \mathbf{b}_{\leq i})$.</center>

# @markdown We arrived at this approximation by working through constraints imposed by the interaction (details at the end).

import torch
import torch.nn as nn
import torch.nn.functional as F

# @markdown #### **Decoder**

# @markdown <center><img src="https://raw.githubusercontent.com/chrisdonahue/music-cocreation-tutorial/main/part-1-py-training/figures/decoder.png" width=600px/></center>
# @markdown <center><b>Piano Genie decoder processing $N=4$ notes</b></center>

# @markdown The approximation $P(k_i \mid \mathbf{k}_{<i}, \mathbf{t}_{\leq i}, \mathbf{b}_{\leq i})$ constitutes the decoder of Piano Genie, which we will parameterize using an RNN.
# @markdown This is the portion of the model that users will interact with.
# @markdown To achieve our intended real-time interaction, we will compute and sample from this RNN at the instant the user presses a button, passing as input the key from the previous timestep, the current time, the button the user pressed, and a vector which summarizes the ongoing history.

# @markdown Formally, the decoder is a function:
# @markdown $D_{\theta}: k_{i-1}, t_i, b_i, \mathbf{h}_{i-1} \mapsto \mathbf{\hat{k}}_i, \mathbf{h}_i$, where:

# @markdown - $k_0$ is a special start-of-sequence token $<\text{S}>$

# @markdown - $\mathbf{h}_i$ is a vector summarizing timesteps $1, \ldots, i$

# @markdown - $\mathbf{h}_0$ is some initial value (zeros) for that vector

# @markdown - $\mathbf{\hat{k}}_i \in \mathbb{R}^{88}$ are the output logits for timestep $i$

SOS = PIANO_NUM_KEYS

class PianoGenieDecoder(nn.Module):
    def __init__(self, rnn_dim=128, rnn_num_layers=2):
        super().__init__()
        self.rnn_dim = rnn_dim
        self.rnn_num_layers = rnn_num_layers
        #change this to 4
        self.input = nn.Linear(PIANO_NUM_KEYS + 4, rnn_dim)
        self.lstm = nn.LSTM(
            rnn_dim,
            rnn_dim,
            rnn_num_layers,
            batch_first=True,
            bidirectional=False,
        )
        self.output = nn.Linear(rnn_dim, 88)

    def init_hidden(self, batch_size, device=None):
        h = torch.zeros(self.rnn_num_layers, batch_size, self.rnn_dim, device=device)
        c = torch.zeros(self.rnn_num_layers, batch_size, self.rnn_dim, device=device)
        return (h, c)

    def forward(self, k, t, b, v, h_0=None):
        # Prepend <S> token to shift k_i to k_{i-1}
        k_m1 = torch.cat([torch.full_like(k[:, :1], SOS), k[:, :-1]], dim=1)

        # Encode input
        inputs = [
            F.one_hot(k_m1, PIANO_NUM_KEYS + 1),
            t.unsqueeze(dim=2),
            b.unsqueeze(dim=2),
            v.unsqueeze(dim=2),
        ]
        x = torch.cat(inputs, dim=2)

        # Project encoded inputs
        x = self.input(x)

        # Run RNN
        if h_0 is None:
            h_0 = self.init_hidden(k.shape[0], device=k.device)
        x, h_N = self.lstm(x, h_0)

        # Compute logits
        hat_k = self.output(x)

        return hat_k, h_N


# @markdown #### **Encoder**

# @markdown <center><img src="https://i.imgur.com/P3bQFsC.png" width=600px/></center>
# @markdown <center><b>Piano Genie encoder processing $N=4$ notes</b></center>

# @markdown Because we lack examples of human button performances, we use an encoder to automatically learn to map piano performances into synthetic button performances.
# @markdown The encoder takes as input a sequence of keys and onset times and produces an equal-length sequence of buttons.
# @markdown Formally, the encoder is a function: $E_{\varphi} : \mathbf{k}, \mathbf{t} \mapsto \mathbf{b}$.

# @markdown Note the conceptual difference between the decoder and the encoder: the decoder process one sequence item at a time, while the encoder maps an entire input sequence to an output sequence.
# @markdown This is because the decoder (which we will use during inference) needs to process information as it becomes available in real time, whereas the encoder (which we only use during training) can observe the entire piano sequence before translating it into buttons.
# @markdown Despite this conceptual difference, in practice the encoder is also an RNN (though a bidirectional one) under the hood.

class PianoGenieEncoder(nn.Module):
    def __init__(self, rnn_dim=128, rnn_num_layers=2):
        super().__init__()
        self.rnn_dim = rnn_dim
        self.rnn_num_layers = rnn_num_layers
        self.input = nn.Linear(PIANO_NUM_KEYS + 2, rnn_dim)
        self.lstm = nn.LSTM(
            rnn_dim,
            rnn_dim,
            rnn_num_layers,
            batch_first=True,
            bidirectional=True,
        )
        self.output = nn.Linear(rnn_dim * 2, 1)

    def forward(self, k, t, v):
        inputs = [
            F.one_hot(k, PIANO_NUM_KEYS),
            t.unsqueeze(dim=2),
            v.unsqueeze(dim=2),
        ]
        x = self.input(torch.cat(inputs, dim=2))
        # NOTE: PyTorch uses zeros automatically if h is None
        x, _ = self.lstm(x, None)
        x = self.output(x)
        return x[:, :, 0]


# @markdown #### **Quantizing encoder output to discrete buttons**

# @markdown <center><img src="https://raw.githubusercontent.com/chrisdonahue/music-cocreation-tutorial/main/part-1-py-training/figures/quantization.png" width=600px/></center>
# @markdown <center><b>Quantizing continuous encoder output (grey line) to eight discrete values (colorful line segments)</b></center>

# @markdown You may have noticed in the code that the encoder outputs a real-valued scalar (let's call it $e_i \in \mathbb{R}$) at each timestep, but our goal is to output one of eight discrete buttons, i.e., $b_i \in \mathbb{B}$.
# @markdown To achieve this, we will quantize this real-valued scalar as the centroid of the nearest of eight bins between $[-1, 1]$ (see figure above):

# @markdown <center>$b_i = 2 \cdot \frac{\tilde{b}_i - 1}{B - 1} - 1$, where $\tilde{b}_i = \text{round} \left( 1 + (B - 1) \cdot \min \left( \max \left( \frac{e_i  + 1}{2}, 0 \right), 1 \right) \right)$</center>

class IntegerQuantizer(nn.Module):
    def __init__(self, num_bins):
        super().__init__()
        self.num_bins = num_bins

    def real_to_discrete(self, x, eps=1e-6):
        x = (x + 1) / 2
        x = torch.clamp(x, 0, 1)
        x *= self.num_bins - 1
        x = (torch.round(x) + eps).long()
        return x

    def discrete_to_real(self, x):
        x = x.float()
        x /= self.num_bins - 1
        x = (x * 2) - 1
        return x

    def forward(self, x):
        # Quantize and compute delta (used for straight-through estimator)
        with torch.no_grad():
            x_disc = self.real_to_discrete(x)
            x_quant = self.discrete_to_real(x_disc)
            x_quant_delta = x_quant - x

        # @markdown In the backwards pass, we will use the straight-through estimator (Bengio et al. 2013), i.e., pretend that this discretization did not happen when computing gradients.
        # Quantize w/ straight-through estimator
        x = x + x_quant_delta

        return x


# @markdown #### **Defining the autoencoder**

# @markdown Finally, the Piano Genie autoencoder is simply the composition of the encoder, quantizer, and decoder.

class PianoGenieAutoencoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.enc = PianoGenieEncoder(
            rnn_dim=cfg["model_rnn_dim"],
            rnn_num_layers=cfg["model_rnn_num_layers"],
        )
        self.quant = IntegerQuantizer(cfg["num_buttons"])
        self.dec = PianoGenieDecoder(
            rnn_dim=cfg["model_rnn_dim"],
            rnn_num_layers=cfg["model_rnn_num_layers"],
        )

    def forward(self, k, t, v):
        e = self.enc(k, t, v)
        b = self.quant(e)
        hat_k, _ = self.dec(k, t, b, v)
        return hat_k, e


# @markdown #### **Approximating $P(x_i \mid \mathbf{x}_{< i}, \mathbf{c})$**

# @markdown This section walks through how we designed an approximation to $P(x_i \mid \mathbf{x}_{< i}, \mathbf{c})$ which would be appropriate for our intended interaction. You probably don't need to understand this, but some may find it helpful as an illustration of how to design a generative model around constraints imposed by interaction.

# @markdown First, we expand the terms, treating the onsets $\mathbf{t}$ and durations $\mathbf{d}$ as part of the button performance $\mathbf{c}$:

# @markdown <center>$P(x_i \mid \mathbf{x}_{< i}, \mathbf{c}) = P(k_i, v_i \mid \mathbf{k}_{<i}, \mathbf{v}_{<i}, \mathbf{t}, \mathbf{d}, \mathbf{b})$</center>

# @markdown Because we want this interaction to be real-time, we must remove any information that might not be available at time $t_i$ (the moment the user presses a button), which includes future onsets $\mathbf{t}_{>i}$, future buttons $\mathbf{b}_{>i}$, and all durations $\mathbf{d}$, since notes can be held indefinitely:

# @markdown <center>$\approx P(k_i, v_i \mid \mathbf{k}_{<i}, \mathbf{v}_{<i}, \mathbf{t}_{\leq i}, \mathbf{b}_{\leq i})$</center>

# @markdown Finally, we anticipate that it will be frustrating for users if the model predicts dynamics on their behalf, so we remove velocity terms $\mathbf{v}$:

# @markdown <center>$\approx P(k_i, \mid \mathbf{k}_{<i}, \mathbf{t}_{\leq i}, \mathbf{b}_{\leq i})$</center>

(STEP 3) Modifying training pipeline (NEW)¶

After modifying the autoencoder, we needed to figure out a way in order to fix the function calls in the training pipeline so the encoder could properly receive the velocity values, and that we could properly interpret the new output.

We needed to add a new batch_v parameter. batch_v is a list that stores the velocity values for each note in a minibatch of piano performances. n[3] represents the velocity value in a note, which can be derived by looking at how the download_and_parse_maestro function structures the notes list.

In context, the changes described are below.

# Key features
        batch_k.append([n[2] for n in subsample])
        batch_v.append([n[3] for n in subsample])

        # Onset features
        # NOTE: For stability, we pass delta time to Piano Genie instead of time.
        t = np.diff([n[0] for n in subsample])
        t = np.concatenate([[1e8], t])
        t = np.clip(t, 0, CFG["data_delta_time_max"])
        batch_t.append(t)

    return (torch.tensor(batch_k).long(), torch.tensor(batch_t).float(), torch.tensor(batch_v).float())
In [23]:
'@title **(Step 3)** Train Piano Genie'

# @markdown *Note*: Check this box to log training curves to [Weights & Biases](https://wandb.ai/) (which will prompt you to log in).
USE_WANDB = False  # @param{type:"boolean"}

# @markdown Now that we've defined the autoencoder, we need to train it.
# @markdown We will train the entire autoencoder end-to-end to minimize the reconstruction loss of the decoder.

# @markdown <center>$\mathcal{L}_{\text{recons}} = \frac{1}{N} \sum_{i=1}^{N} \text{CrossEntropy}(\text{Softmax}(\mathbf{\hat{k}}_i), k_i)$</center>

# @markdown This loss alone does not encourage the encoder to produce button sequences with any particular structure, so the behavior of the decoder will likely be fairly unpredictable at interaction time.
# @markdown We think it might be intuitive to users if the decoder respected the _contour_ of their performance, i.e., if higher buttons produced higher notes and lower buttons produced lower notes.
# @markdown Hence, we include a loss term which encourages the encoder to produces button sequences which align with the contour of the piano key sequences.

# @markdown <center>$\mathcal{L}_{\text{contour}} = \frac{1}{N - 1} \sum_{i=2}^{N} \max (0, 1 - (k_i - k_{i-1}) \cdot (e_i - e_{i-1}))^2$</center>

# @markdown Finally, we find empirically that the encoder often outputs values outside of the $[-1, 1]$ range used for discretization.
# @markdown Hence, we add a loss term which explicitly encourages this behavior

# @markdown <center>$\mathcal{L}_{\text{margin}} = \frac{1}{N} \sum_{i=1}^{N} \max(0, |e_i| - 1)^2$</center>

# @markdown Thus, our final loss function is:
# @markdown <center>$\mathcal{L} = \mathcal{L}_{\text{recons}} + \mathcal{L}_{\text{contour}} + \mathcal{L}_{\text{margin}}$</center>


CFG = {
    "seed": 0,
    # Number of buttons in interface
    "num_buttons": 26,
    # Onset delta times will be clipped to this maximum
    "data_delta_time_max": 1.0,
    # Max time stretch for data augmentation (+- 5%)
    "data_augment_time_stretch_max": 0.05,
    # Max transposition for data augmentation (+- tritone)
    "data_augment_transpose_max": 6,
    # RNN dimensionality
    "model_rnn_dim": 128,
    # RNN num layers
    "model_rnn_num_layers": 2,
    # Training hyperparameters
    "batch_size": 32,
    "seq_len": 128,
    "lr": 3e-4,
    "loss_margin_multiplier": 1.0,
    "loss_contour_multiplier": 1.0,
    "summarize_frequency": 128,
    "eval_frequency": 128,
    "max_num_steps": 50000
}

import pathlib
import random

import numpy as np

if USE_WANDB:
    try:
        import wandb
    except ModuleNotFoundError:
        !!pip install wandb
        import wandb

# Init
run_dir = pathlib.Path("piano_genie")
run_dir.mkdir(exist_ok=True)
with open(pathlib.Path(run_dir, "cfg.json"), "w") as f:
    f.write(json.dumps(CFG, indent=2))
if USE_WANDB:
    wandb.init(project="music-cocreation-tutorial", config=CFG, reinit=True)

# Set seed
if CFG["seed"] is not None:
    random.seed(CFG["seed"])
    np.random.seed(CFG["seed"])
    torch.manual_seed(CFG["seed"])
    torch.cuda.manual_seed_all(CFG["seed"])

# Create model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.backends.mps.is_available():
    print("MPS available")
    device = torch.device("mps")
    
model = PianoGenieAutoencoder(CFG)
model.train()
model.to(device)
print("-" * 80)
for n, p in model.named_parameters():
    print(f"{n}, {p.shape}")

# Create optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=CFG["lr"])

# Subsamples performances to create a minibatch
def performances_to_batch(performances, device, train=True):
    batch_k = []
    batch_t = []
    batch_v = []
    for p in performances:
        # Subsample seq_len notes from performance
        assert len(p) >= CFG["seq_len"]
        if train:
            subsample_offset = random.randrange(0, len(p) - CFG["seq_len"])
        else:
            subsample_offset = 0
        subsample = p[subsample_offset : subsample_offset + CFG["seq_len"]]
        assert len(subsample) == CFG["seq_len"]

        # Data augmentation
        if train:
            stretch_factor = random.random() * CFG["data_augment_time_stretch_max"] * 2
            stretch_factor += 1 - CFG["data_augment_time_stretch_max"]
            transposition_factor = random.randint(
                -CFG["data_augment_transpose_max"], CFG["data_augment_transpose_max"]
            )
            subsample = [
                (
                    n[0] * stretch_factor,
                    n[1] * stretch_factor,
                    max(0, min(n[2] + transposition_factor, PIANO_NUM_KEYS - 1)),
                    n[3],
                )
                for n in subsample
            ]

        # Key features
        batch_k.append([n[2] for n in subsample])
        batch_v.append([n[3] for n in subsample])

        # Onset features
        # NOTE: For stability, we pass delta time to Piano Genie instead of time.
        t = np.diff([n[0] for n in subsample])
        t = np.concatenate([[1e8], t])
        t = np.clip(t, 0, CFG["data_delta_time_max"])
        batch_t.append(t)

    return (torch.tensor(batch_k).long(), torch.tensor(batch_t).float(), torch.tensor(batch_v).float())


# Train
step = 0
best_eval_loss = float("inf")
while CFG["max_num_steps"] is None or step < CFG["max_num_steps"]:
    if step % CFG["eval_frequency"] == 0:
        model.eval()

        with torch.no_grad():
            eval_losses_recons = []
            eval_violates_contour = []
            for i in range(0, len(DATASET["validation"]), CFG["batch_size"]):
                eval_batch = performances_to_batch(
                    DATASET["validation"][i : i + CFG["batch_size"]],
                    device,
                    train=False,
                )
                eval_k, eval_t, eval_v = tuple(t.to(device) for t in eval_batch)
                eval_hat_k, eval_e = model(eval_k, eval_t, eval_v)
                eval_b = model.quant.real_to_discrete(eval_e)
                eval_loss_recons = F.cross_entropy(
                    eval_hat_k.view(-1, PIANO_NUM_KEYS),
                    eval_k.view(-1),
                    reduction="none",
                )
                eval_violates = torch.logical_not(
                    torch.sign(torch.diff(eval_k, dim=1))
                    == torch.sign(torch.diff(eval_b, dim=1)),
                ).float()
                eval_violates_contour.extend(eval_violates.cpu().numpy().tolist())
                eval_losses_recons.extend(eval_loss_recons.cpu().numpy().tolist())

            eval_loss_recons = np.mean(eval_losses_recons)
            if eval_loss_recons < best_eval_loss:
                torch.save(model.state_dict(), pathlib.Path(run_dir, "model.pt"))
                best_eval_loss = eval_loss_recons

        eval_metrics = {
            "eval_loss_recons": eval_loss_recons,
            "eval_contour_violation_ratio": np.mean(eval_violates_contour),
        }
        if USE_WANDB:
            wandb.log(eval_metrics, step=step)
        print(step, "eval", eval_metrics)

        model.train()

    # Create minibatch
    batch = performances_to_batch(
        random.sample(DATASET["train"], CFG["batch_size"]), device, train=True
    )
    k, t, v = tuple(t.to(device) for t in batch)

    # Run model
    optimizer.zero_grad()
    k_hat, e = model(k, t, v)

    # Compute losses and update params
    loss_recons = F.cross_entropy(k_hat.view(-1, PIANO_NUM_KEYS), k.view(-1))
    loss_margin = torch.square(
        torch.maximum(torch.abs(e) - 1, torch.zeros_like(e))
    ).mean()
    loss_contour = torch.square(
        torch.maximum(
            1 - torch.diff(k, dim=1) * torch.diff(e, dim=1),
            torch.zeros_like(e[:, 1:]),
        )
    ).mean()
    loss = torch.zeros_like(loss_recons)
    loss += loss_recons
    if CFG["loss_margin_multiplier"] > 0:
        loss += CFG["loss_margin_multiplier"] * loss_margin
    if CFG["loss_contour_multiplier"] > 0:
        loss += CFG["loss_contour_multiplier"] * loss_contour
    loss.backward()
    optimizer.step()
    step += 1

    if step % CFG["summarize_frequency"] == 0:
        metrics = {
            "loss_recons": loss_recons.item(),
            "loss_margin": loss_margin.item(),
            "loss_contour": loss_contour.item(),
            "loss": loss.item(),
        }
        if USE_WANDB:
            wandb.log(metrics, step=step)
        print(step, "train", metrics)

# Download the trained model so we don't lose it!
# from google.colab import files

# files.download('piano_genie/model.pt')
# files.download('piano_genie/cfg.json')
--------------------------------------------------------------------------------
enc.input.weight, torch.Size([128, 90])
enc.input.bias, torch.Size([128])
enc.lstm.weight_ih_l0, torch.Size([512, 128])
enc.lstm.weight_hh_l0, torch.Size([512, 128])
enc.lstm.bias_ih_l0, torch.Size([512])
enc.lstm.bias_hh_l0, torch.Size([512])
enc.lstm.weight_ih_l0_reverse, torch.Size([512, 128])
enc.lstm.weight_hh_l0_reverse, torch.Size([512, 128])
enc.lstm.bias_ih_l0_reverse, torch.Size([512])
enc.lstm.bias_hh_l0_reverse, torch.Size([512])
enc.lstm.weight_ih_l1, torch.Size([512, 256])
enc.lstm.weight_hh_l1, torch.Size([512, 128])
enc.lstm.bias_ih_l1, torch.Size([512])
enc.lstm.bias_hh_l1, torch.Size([512])
enc.lstm.weight_ih_l1_reverse, torch.Size([512, 256])
enc.lstm.weight_hh_l1_reverse, torch.Size([512, 128])
enc.lstm.bias_ih_l1_reverse, torch.Size([512])
enc.lstm.bias_hh_l1_reverse, torch.Size([512])
enc.output.weight, torch.Size([1, 256])
enc.output.bias, torch.Size([1])
dec.input.weight, torch.Size([128, 92])
dec.input.bias, torch.Size([128])
dec.lstm.weight_ih_l0, torch.Size([512, 128])
dec.lstm.weight_hh_l0, torch.Size([512, 128])
dec.lstm.bias_ih_l0, torch.Size([512])
dec.lstm.bias_hh_l0, torch.Size([512])
dec.lstm.weight_ih_l1, torch.Size([512, 128])
dec.lstm.weight_hh_l1, torch.Size([512, 128])
dec.lstm.bias_ih_l1, torch.Size([512])
dec.lstm.bias_hh_l1, torch.Size([512])
dec.output.weight, torch.Size([88, 128])
dec.output.bias, torch.Size([88])
0 eval {'eval_loss_recons': np.float64(4.486702833814125), 'eval_contour_violation_ratio': np.float64(0.9079832174262888)}
128 train {'loss_recons': 4.045757293701172, 'loss_margin': 0.0, 'loss_contour': 0.3008792996406555, 'loss': 4.346636772155762}
128 eval {'eval_loss_recons': np.float64(3.9640611213951433), 'eval_contour_violation_ratio': np.float64(0.17575722742686362)}
256 train {'loss_recons': 4.03377628326416, 'loss_margin': 0.005535581149160862, 'loss_contour': 0.16457054018974304, 'loss': 4.203882217407227}
256 eval {'eval_loss_recons': np.float64(3.9295664395810697), 'eval_contour_violation_ratio': np.float64(0.04655439967814242)}
384 train {'loss_recons': 3.716930389404297, 'loss_margin': 0.022140074521303177, 'loss_contour': 0.12826721370220184, 'loss': 3.867337703704834}
384 eval {'eval_loss_recons': np.float64(3.5799948769573966), 'eval_contour_violation_ratio': np.float64(0.06029082131156963)}
512 train {'loss_recons': 3.287353992462158, 'loss_margin': 0.014329938217997551, 'loss_contour': 0.16425269842147827, 'loss': 3.4659366607666016}
512 eval {'eval_loss_recons': np.float64(3.2015453052917753), 'eval_contour_violation_ratio': np.float64(0.05764699120639117)}
640 train {'loss_recons': 3.088524580001831, 'loss_margin': 0.02659253403544426, 'loss_contour': 0.13125230371952057, 'loss': 3.2463693618774414}
640 eval {'eval_loss_recons': np.float64(2.9751736059866465), 'eval_contour_violation_ratio': np.float64(0.06086556698660842)}
768 train {'loss_recons': 2.848477363586426, 'loss_margin': 0.03840359300374985, 'loss_contour': 0.1337859034538269, 'loss': 3.0206668376922607}
768 eval {'eval_loss_recons': np.float64(2.7032569296535676), 'eval_contour_violation_ratio': np.float64(0.09943100178171159)}
896 train {'loss_recons': 2.701693058013916, 'loss_margin': 0.037429437041282654, 'loss_contour': 0.13206665217876434, 'loss': 2.8711891174316406}
896 eval {'eval_loss_recons': np.float64(2.5992382614037197), 'eval_contour_violation_ratio': np.float64(0.07339502270245417)}
1024 train {'loss_recons': 2.4553098678588867, 'loss_margin': 0.02616886980831623, 'loss_contour': 0.12079127877950668, 'loss': 2.6022698879241943}
1024 eval {'eval_loss_recons': np.float64(2.44766631360798), 'eval_contour_violation_ratio': np.float64(0.07017644692223692)}
1152 train {'loss_recons': 2.509087562561035, 'loss_margin': 0.0363595113158226, 'loss_contour': 0.16777727007865906, 'loss': 2.713224411010742}
1152 eval {'eval_loss_recons': np.float64(2.3444407584662312), 'eval_contour_violation_ratio': np.float64(0.07391229380998908)}
1280 train {'loss_recons': 2.2422025203704834, 'loss_margin': 0.037162892520427704, 'loss_contour': 0.10402068495750427, 'loss': 2.3833858966827393}
1280 eval {'eval_loss_recons': np.float64(2.1800529691016806), 'eval_contour_violation_ratio': np.float64(0.08552215644577274)}
1408 train {'loss_recons': 2.2662577629089355, 'loss_margin': 0.05142185091972351, 'loss_contour': 0.13315381109714508, 'loss': 2.450833559036255}
1408 eval {'eval_loss_recons': np.float64(2.137325938417148), 'eval_contour_violation_ratio': np.float64(0.08793608828093569)}
1536 train {'loss_recons': 2.161475658416748, 'loss_margin': 0.060236185789108276, 'loss_contour': 0.14116115868091583, 'loss': 2.362873077392578}
1536 eval {'eval_loss_recons': np.float64(2.0478662810267974), 'eval_contour_violation_ratio': np.float64(0.09586757859647106)}
1664 train {'loss_recons': 2.082367181777954, 'loss_margin': 0.05105502903461456, 'loss_contour': 0.1513219177722931, 'loss': 2.2847440242767334}
1664 eval {'eval_loss_recons': np.float64(1.9899188224022297), 'eval_contour_violation_ratio': np.float64(0.08408529225817575)}
1792 train {'loss_recons': 2.088503360748291, 'loss_margin': 0.07366366684436798, 'loss_contour': 0.12906353175640106, 'loss': 2.2912306785583496}
1792 eval {'eval_loss_recons': np.float64(1.9252681264952913), 'eval_contour_violation_ratio': np.float64(0.08943042703603656)}
1920 train {'loss_recons': 2.073911428451538, 'loss_margin': 0.08800145983695984, 'loss_contour': 0.1431877166032791, 'loss': 2.305100679397583}
1920 eval {'eval_loss_recons': np.float64(1.8808095674200431), 'eval_contour_violation_ratio': np.float64(0.08989022357606759)}
2048 train {'loss_recons': 2.039985179901123, 'loss_margin': 0.08667181432247162, 'loss_contour': 0.10114700347185135, 'loss': 2.227803945541382}
2048 eval {'eval_loss_recons': np.float64(1.9037467075008763), 'eval_contour_violation_ratio': np.float64(0.09759181562158745)}
2176 train {'loss_recons': 2.0135226249694824, 'loss_margin': 0.1125592589378357, 'loss_contour': 0.12975536286830902, 'loss': 2.2558372020721436}
2176 eval {'eval_loss_recons': np.float64(1.8036477549042362), 'eval_contour_violation_ratio': np.float64(0.08897063049600552)}
2304 train {'loss_recons': 1.8784253597259521, 'loss_margin': 0.08721557259559631, 'loss_contour': 0.15489044785499573, 'loss': 2.1205313205718994}
2304 eval {'eval_loss_recons': np.float64(1.7609738247047593), 'eval_contour_violation_ratio': np.float64(0.0873613426058969)}
2432 train {'loss_recons': 2.013939380645752, 'loss_margin': 0.10573558509349823, 'loss_contour': 0.16961482167243958, 'loss': 2.289289712905884}
2432 eval {'eval_loss_recons': np.float64(1.9170958312486646), 'eval_contour_violation_ratio': np.float64(0.08408529225817575)}
2560 train {'loss_recons': 1.8989622592926025, 'loss_margin': 0.12745696306228638, 'loss_contour': 0.15471376478672028, 'loss': 2.1811330318450928}
2560 eval {'eval_loss_recons': np.float64(1.7530651504418602), 'eval_contour_violation_ratio': np.float64(0.09333869762630036)}
2688 train {'loss_recons': 1.9046034812927246, 'loss_margin': 0.12125282734632492, 'loss_contour': 0.1334172636270523, 'loss': 2.159273624420166}
2688 eval {'eval_loss_recons': np.float64(2.083750153405687), 'eval_contour_violation_ratio': np.float64(0.0833955974481292)}
2816 train {'loss_recons': 1.8099160194396973, 'loss_margin': 0.09975746273994446, 'loss_contour': 0.17230945825576782, 'loss': 2.0819828510284424}
2816 eval {'eval_loss_recons': np.float64(1.6815270481216484), 'eval_contour_violation_ratio': np.float64(0.08402781769067187)}
2944 train {'loss_recons': 1.628737449645996, 'loss_margin': 0.10950950533151627, 'loss_contour': 0.14879833161830902, 'loss': 1.8870452642440796}
2944 eval {'eval_loss_recons': np.float64(1.65200920452289), 'eval_contour_violation_ratio': np.float64(0.08793608828093569)}
3072 train {'loss_recons': 1.908745527267456, 'loss_margin': 0.08145256340503693, 'loss_contour': 0.17408296465873718, 'loss': 2.164281129837036}
3072 eval {'eval_loss_recons': np.float64(1.5897158215155511), 'eval_contour_violation_ratio': np.float64(0.08655669866084258)}
3200 train {'loss_recons': 1.5739221572875977, 'loss_margin': 0.11387868225574493, 'loss_contour': 0.1532912701368332, 'loss': 1.8410921096801758}
3200 eval {'eval_loss_recons': np.float64(1.5784609286032998), 'eval_contour_violation_ratio': np.float64(0.08592447841829991)}
3328 train {'loss_recons': 1.7547698020935059, 'loss_margin': 0.16025133430957794, 'loss_contour': 0.1538713127374649, 'loss': 2.068892478942871}
3328 eval {'eval_loss_recons': np.float64(1.5933797719725238), 'eval_contour_violation_ratio': np.float64(0.08477498706822231)}
3456 train {'loss_recons': 1.8798270225524902, 'loss_margin': 0.21757765114307404, 'loss_contour': 0.1630021631717682, 'loss': 2.260406970977783}
3456 eval {'eval_loss_recons': np.float64(1.9329883502252454), 'eval_contour_violation_ratio': np.float64(0.08293580090809817)}
3584 train {'loss_recons': 2.064211845397949, 'loss_margin': 0.185592383146286, 'loss_contour': 0.14022651314735413, 'loss': 2.390030860900879}
3584 eval {'eval_loss_recons': np.float64(1.6268777469798488), 'eval_contour_violation_ratio': np.float64(0.09356859589631589)}
3712 train {'loss_recons': 1.612583875656128, 'loss_margin': 0.10940082371234894, 'loss_contour': 0.12168930470943451, 'loss': 1.8436740636825562}
3712 eval {'eval_loss_recons': np.float64(1.5600832280580526), 'eval_contour_violation_ratio': np.float64(0.08402781769067187)}
3840 train {'loss_recons': 1.5684406757354736, 'loss_margin': 0.12149357795715332, 'loss_contour': 0.10603947192430496, 'loss': 1.795973777770996}
3840 eval {'eval_loss_recons': np.float64(1.567133077002219), 'eval_contour_violation_ratio': np.float64(0.06695787114201966)}
3968 train {'loss_recons': 1.5611164569854736, 'loss_margin': 0.25837352871894836, 'loss_contour': 0.15077067911624908, 'loss': 1.9702606201171875}
3968 eval {'eval_loss_recons': np.float64(1.4642768024048856), 'eval_contour_violation_ratio': np.float64(0.07155583654233003)}
4096 train {'loss_recons': 1.4864790439605713, 'loss_margin': 0.30655398964881897, 'loss_contour': 0.1433219164609909, 'loss': 1.9363548755645752}
4096 eval {'eval_loss_recons': np.float64(1.4645303429043206), 'eval_contour_violation_ratio': np.float64(0.08552215644577274)}
4224 train {'loss_recons': 1.471989631652832, 'loss_margin': 0.1621730625629425, 'loss_contour': 0.13033008575439453, 'loss': 1.7644927501678467}
4224 eval {'eval_loss_recons': np.float64(1.4189739567346626), 'eval_contour_violation_ratio': np.float64(0.07276280245991149)}
4352 train {'loss_recons': 1.565762996673584, 'loss_margin': 0.22840626537799835, 'loss_contour': 0.10845804214477539, 'loss': 1.9026273488998413}
4352 eval {'eval_loss_recons': np.float64(1.8246695847236925), 'eval_contour_violation_ratio': np.float64(0.07057876889476407)}
4480 train {'loss_recons': 1.6475613117218018, 'loss_margin': 0.3762543797492981, 'loss_contour': 0.15919066965579987, 'loss': 2.1830062866210938}
4480 eval {'eval_loss_recons': np.float64(1.5658380993298615), 'eval_contour_violation_ratio': np.float64(0.07506178516006667)}
4608 train {'loss_recons': 1.4232971668243408, 'loss_margin': 0.23658373951911926, 'loss_contour': 0.1199691966176033, 'loss': 1.7798501253128052}
4608 eval {'eval_loss_recons': np.float64(1.4757614407423258), 'eval_contour_violation_ratio': np.float64(0.07448703948502787)}
4736 train {'loss_recons': 1.7338950634002686, 'loss_margin': 0.2923099100589752, 'loss_contour': 0.15047283470630646, 'loss': 2.176677942276001}
4736 eval {'eval_loss_recons': np.float64(1.6295470720680472), 'eval_contour_violation_ratio': np.float64(0.07661359848267142)}
4864 train {'loss_recons': 1.452807903289795, 'loss_margin': 0.24675171077251434, 'loss_contour': 0.15831735730171204, 'loss': 1.8578768968582153}
4864 eval {'eval_loss_recons': np.float64(1.3921173777070956), 'eval_contour_violation_ratio': np.float64(0.07431461578251623)}
4992 train {'loss_recons': 1.4999351501464844, 'loss_margin': 0.35676443576812744, 'loss_contour': 0.15845543146133423, 'loss': 2.015155076980591}
4992 eval {'eval_loss_recons': np.float64(1.534869652360433), 'eval_contour_violation_ratio': np.float64(0.07678602218518306)}
5120 train {'loss_recons': 1.5338282585144043, 'loss_margin': 0.3379828929901123, 'loss_contour': 0.1408671885728836, 'loss': 2.012678384780884}
5120 eval {'eval_loss_recons': np.float64(1.380197944279325), 'eval_contour_violation_ratio': np.float64(0.07534915799758607)}
5248 train {'loss_recons': 1.4508225917816162, 'loss_margin': 0.7804330587387085, 'loss_contour': 0.11559084057807922, 'loss': 2.346846342086792}
5248 eval {'eval_loss_recons': np.float64(1.3084713271037287), 'eval_contour_violation_ratio': np.float64(0.07391229380998908)}
5376 train {'loss_recons': 1.4734809398651123, 'loss_margin': 0.5535368919372559, 'loss_contour': 0.1847905069589615, 'loss': 2.211808443069458}
5376 eval {'eval_loss_recons': np.float64(1.3241865348689477), 'eval_contour_violation_ratio': np.float64(0.06506121041439163)}
5504 train {'loss_recons': 1.4334495067596436, 'loss_margin': 0.31113165616989136, 'loss_contour': 0.16205523908138275, 'loss': 1.9066364765167236}
5504 eval {'eval_loss_recons': np.float64(1.2884144772313209), 'eval_contour_violation_ratio': np.float64(0.061210414391631704)}
5632 train {'loss_recons': 1.6127394437789917, 'loss_margin': 0.7549182772636414, 'loss_contour': 0.1321142017841339, 'loss': 2.4997718334198}
5632 eval {'eval_loss_recons': np.float64(1.33852285899589), 'eval_contour_violation_ratio': np.float64(0.06350939709178688)}
5760 train {'loss_recons': 1.3609870672225952, 'loss_margin': 0.3955615162849426, 'loss_contour': 0.13746277987957, 'loss': 1.894011378288269}
5760 eval {'eval_loss_recons': np.float64(1.2483288102280212), 'eval_contour_violation_ratio': np.float64(0.0631645496867636)}
5888 train {'loss_recons': 1.5010325908660889, 'loss_margin': 0.3259492516517639, 'loss_contour': 0.18166011571884155, 'loss': 2.0086419582366943}
5888 eval {'eval_loss_recons': np.float64(1.3550940237048823), 'eval_contour_violation_ratio': np.float64(0.059256279096499796)}
6016 train {'loss_recons': 1.8305778503417969, 'loss_margin': 0.4306517541408539, 'loss_contour': 0.18726365268230438, 'loss': 2.448493242263794}
6016 eval {'eval_loss_recons': np.float64(1.414220527874945), 'eval_contour_violation_ratio': np.float64(0.06103799068912007)}
6144 train {'loss_recons': 1.53910231590271, 'loss_margin': 0.6993337869644165, 'loss_contour': 0.15320409834384918, 'loss': 2.3916404247283936}
6144 eval {'eval_loss_recons': np.float64(1.4488481925630494), 'eval_contour_violation_ratio': np.float64(0.06098051612161619)}
6272 train {'loss_recons': 1.6230546236038208, 'loss_margin': 1.0182437896728516, 'loss_contour': 0.15041562914848328, 'loss': 2.7917139530181885}
6272 eval {'eval_loss_recons': np.float64(1.5010266807576111), 'eval_contour_violation_ratio': np.float64(0.07006149778722916)}
6400 train {'loss_recons': 1.4945621490478516, 'loss_margin': 0.48849159479141235, 'loss_contour': 0.14237532019615173, 'loss': 2.1254289150238037}
6400 eval {'eval_loss_recons': np.float64(1.441297914411356), 'eval_contour_violation_ratio': np.float64(0.06425656646933732)}
6528 train {'loss_recons': 1.3049036264419556, 'loss_margin': 0.1982090175151825, 'loss_contour': 0.13851682841777802, 'loss': 1.6416294574737549}
6528 eval {'eval_loss_recons': np.float64(1.2423460108812654), 'eval_contour_violation_ratio': np.float64(0.060693143284096786)}
6656 train {'loss_recons': 1.2830111980438232, 'loss_margin': 0.38746166229248047, 'loss_contour': 0.14374491572380066, 'loss': 1.8142178058624268}
6656 eval {'eval_loss_recons': np.float64(1.1936567761530814), 'eval_contour_violation_ratio': np.float64(0.06350939709178688)}
6784 train {'loss_recons': 1.193464994430542, 'loss_margin': 0.49111682176589966, 'loss_contour': 0.12262251228094101, 'loss': 1.807204246520996}
6784 eval {'eval_loss_recons': np.float64(1.220692803253875), 'eval_contour_violation_ratio': np.float64(0.07172826024484166)}
6912 train {'loss_recons': 1.2728822231292725, 'loss_margin': 0.7990095615386963, 'loss_contour': 0.13828043639659882, 'loss': 2.210172176361084}
6912 eval {'eval_loss_recons': np.float64(1.194640753438941), 'eval_contour_violation_ratio': np.float64(0.06672797287200413)}
7040 train {'loss_recons': 1.2479420900344849, 'loss_margin': 0.5012987852096558, 'loss_contour': 0.1385989487171173, 'loss': 1.8878397941589355}
7040 eval {'eval_loss_recons': np.float64(1.1707357297529106), 'eval_contour_violation_ratio': np.float64(0.06965917581470199)}
7168 train {'loss_recons': 1.273751974105835, 'loss_margin': 0.5669443011283875, 'loss_contour': 0.16240815818309784, 'loss': 2.0031044483184814}
7168 eval {'eval_loss_recons': np.float64(1.1574766598539903), 'eval_contour_violation_ratio': np.float64(0.06839473532961664)}
7296 train {'loss_recons': 1.496160864830017, 'loss_margin': 0.5128216743469238, 'loss_contour': 0.1453297883272171, 'loss': 2.1543123722076416}
7296 eval {'eval_loss_recons': np.float64(1.5688825468549736), 'eval_contour_violation_ratio': np.float64(0.06500373584688775)}
7424 train {'loss_recons': 1.2524161338806152, 'loss_margin': 0.45981118083000183, 'loss_contour': 0.1336730420589447, 'loss': 1.8459004163742065}
7424 eval {'eval_loss_recons': np.float64(1.191109875616461), 'eval_contour_violation_ratio': np.float64(0.061095465256623946)}
7552 train {'loss_recons': 1.3189947605133057, 'loss_margin': 0.8731250762939453, 'loss_contour': 0.15777412056922913, 'loss': 2.3498940467834473}
7552 eval {'eval_loss_recons': np.float64(1.331331677656675), 'eval_contour_violation_ratio': np.float64(0.06333697338927524)}
7680 train {'loss_recons': 1.3302878141403198, 'loss_margin': 0.9578675627708435, 'loss_contour': 0.14988194406032562, 'loss': 2.438037157058716}
7680 eval {'eval_loss_recons': np.float64(1.2309690890221918), 'eval_contour_violation_ratio': np.float64(0.06253232944422094)}
7808 train {'loss_recons': 1.2198052406311035, 'loss_margin': 1.0714242458343506, 'loss_contour': 0.20691581070423126, 'loss': 2.498145341873169}
7808 eval {'eval_loss_recons': np.float64(1.24474268139979), 'eval_contour_violation_ratio': np.float64(0.06322202425426748)}
7936 train {'loss_recons': 1.2710654735565186, 'loss_margin': 0.8454201221466064, 'loss_contour': 0.1613730639219284, 'loss': 2.2778587341308594}
7936 eval {'eval_loss_recons': np.float64(1.1350598736958455), 'eval_contour_violation_ratio': np.float64(0.06258980401172481)}
8064 train {'loss_recons': 1.2768974304199219, 'loss_margin': 0.7691138982772827, 'loss_contour': 0.1766616702079773, 'loss': 2.222673177719116}
8064 eval {'eval_loss_recons': np.float64(1.117422253493369), 'eval_contour_violation_ratio': np.float64(0.05960112650152308)}
8192 train {'loss_recons': 1.2852106094360352, 'loss_margin': 0.5811178088188171, 'loss_contour': 0.15394362807273865, 'loss': 2.0202720165252686}
8192 eval {'eval_loss_recons': np.float64(1.1450774707885827), 'eval_contour_violation_ratio': np.float64(0.06138283809414334)}
8320 train {'loss_recons': 1.2644710540771484, 'loss_margin': 0.7607988715171814, 'loss_contour': 0.17214882373809814, 'loss': 2.197418689727783}
8320 eval {'eval_loss_recons': np.float64(1.3063366930901215), 'eval_contour_violation_ratio': np.float64(0.05523305937122823)}
8448 train {'loss_recons': 1.824556827545166, 'loss_margin': 1.2107374668121338, 'loss_contour': 0.14937083423137665, 'loss': 3.1846652030944824}
8448 eval {'eval_loss_recons': np.float64(1.5704610521266114), 'eval_contour_violation_ratio': np.float64(0.06281970228174033)}
8576 train {'loss_recons': 1.2954638004302979, 'loss_margin': 0.7291122674942017, 'loss_contour': 0.18136629462242126, 'loss': 2.205942392349243}
8576 eval {'eval_loss_recons': np.float64(1.1209353409235088), 'eval_contour_violation_ratio': np.float64(0.06230243117420541)}
8704 train {'loss_recons': 1.1669384241104126, 'loss_margin': 0.6723990440368652, 'loss_contour': 0.16827492415905, 'loss': 2.007612466812134}
8704 eval {'eval_loss_recons': np.float64(1.0960960054487996), 'eval_contour_violation_ratio': np.float64(0.058336686016437725)}
8832 train {'loss_recons': 1.2091357707977295, 'loss_margin': 1.1053025722503662, 'loss_contour': 0.18938089907169342, 'loss': 2.503819227218628}
8832 eval {'eval_loss_recons': np.float64(1.1428438398450014), 'eval_contour_violation_ratio': np.float64(0.05908385539398816)}
8960 train {'loss_recons': 1.2272248268127441, 'loss_margin': 1.204535961151123, 'loss_contour': 0.23579391837120056, 'loss': 2.6675546169281006}
8960 eval {'eval_loss_recons': np.float64(1.0905550887713031), 'eval_contour_violation_ratio': np.float64(0.05862405885395713)}
9088 train {'loss_recons': 1.189490556716919, 'loss_margin': 1.2103466987609863, 'loss_contour': 0.16955773532390594, 'loss': 2.569395065307617}
9088 eval {'eval_loss_recons': np.float64(1.1532915274945705), 'eval_contour_violation_ratio': np.float64(0.06161273636415886)}
9216 train {'loss_recons': 1.2394051551818848, 'loss_margin': 1.1048643589019775, 'loss_contour': 0.14072172343730927, 'loss': 2.4849913120269775}
9216 eval {'eval_loss_recons': np.float64(1.0854967724681444), 'eval_contour_violation_ratio': np.float64(0.06126788895913558)}
9344 train {'loss_recons': 1.2303688526153564, 'loss_margin': 1.2150816917419434, 'loss_contour': 0.1384437382221222, 'loss': 2.5838942527770996}
9344 eval {'eval_loss_recons': np.float64(1.1404317114757117), 'eval_contour_violation_ratio': np.float64(0.0659808034944537)}
9472 train {'loss_recons': 1.3382995128631592, 'loss_margin': 1.3804070949554443, 'loss_contour': 0.15329276025295258, 'loss': 2.8719992637634277}
9472 eval {'eval_loss_recons': np.float64(1.2704929181265843), 'eval_contour_violation_ratio': np.float64(0.06442899017184896)}
9600 train {'loss_recons': 1.1961256265640259, 'loss_margin': 0.6577322483062744, 'loss_contour': 0.17415858805179596, 'loss': 2.0280165672302246}
9600 eval {'eval_loss_recons': np.float64(1.0871053928353), 'eval_contour_violation_ratio': np.float64(0.06000344847405023)}
9728 train {'loss_recons': 1.248963475227356, 'loss_margin': 1.927502989768982, 'loss_contour': 0.15713156759738922, 'loss': 3.3335981369018555}
9728 eval {'eval_loss_recons': np.float64(1.082449566732657), 'eval_contour_violation_ratio': np.float64(0.06006092304155411)}
9856 train {'loss_recons': 1.4076191186904907, 'loss_margin': 1.1047818660736084, 'loss_contour': 0.19925634562969208, 'loss': 2.7116575241088867}
9856 eval {'eval_loss_recons': np.float64(1.231125801686116), 'eval_contour_violation_ratio': np.float64(0.06270475314673257)}
9984 train {'loss_recons': 1.2447779178619385, 'loss_margin': 1.338110327720642, 'loss_contour': 0.18459223210811615, 'loss': 2.7674803733825684}
9984 eval {'eval_loss_recons': np.float64(1.1161456618043357), 'eval_contour_violation_ratio': np.float64(0.06552100695442267)}
10112 train {'loss_recons': 1.1959927082061768, 'loss_margin': 1.1131417751312256, 'loss_contour': 0.14267878234386444, 'loss': 2.451813220977783}
10112 eval {'eval_loss_recons': np.float64(1.1645754661162921), 'eval_contour_violation_ratio': np.float64(0.06437151560434508)}
10240 train {'loss_recons': 1.2315186262130737, 'loss_margin': 1.3497471809387207, 'loss_contour': 0.14475677907466888, 'loss': 2.726022720336914}
10240 eval {'eval_loss_recons': np.float64(1.0711737768163727), 'eval_contour_violation_ratio': np.float64(0.05678487269383298)}
10368 train {'loss_recons': 1.1651448011398315, 'loss_margin': 1.3755276203155518, 'loss_contour': 0.20437751710414886, 'loss': 2.7450497150421143}
10368 eval {'eval_loss_recons': np.float64(1.0782869124689893), 'eval_contour_violation_ratio': np.float64(0.06759009138456233)}
10496 train {'loss_recons': 1.1733551025390625, 'loss_margin': 1.534288763999939, 'loss_contour': 0.2103395164012909, 'loss': 2.9179835319519043}
10496 eval {'eval_loss_recons': np.float64(1.084725385737422), 'eval_contour_violation_ratio': np.float64(0.0626472785792287)}
10624 train {'loss_recons': 1.332183599472046, 'loss_margin': 1.6243797540664673, 'loss_contour': 0.10933984071016312, 'loss': 3.065903425216675}
10624 eval {'eval_loss_recons': np.float64(1.0877673682921574), 'eval_contour_violation_ratio': np.float64(0.06425656646933732)}
10752 train {'loss_recons': 1.0756598711013794, 'loss_margin': 0.5734738707542419, 'loss_contour': 0.16303323209285736, 'loss': 1.8121669292449951}
10752 eval {'eval_loss_recons': np.float64(1.1510848930057982), 'eval_contour_violation_ratio': np.float64(0.05787688947640669)}
10880 train {'loss_recons': 1.157747745513916, 'loss_margin': 1.803996205329895, 'loss_contour': 0.17632725834846497, 'loss': 3.138071060180664}
10880 eval {'eval_loss_recons': np.float64(1.038467889551386), 'eval_contour_violation_ratio': np.float64(0.06804988792459336)}
11008 train {'loss_recons': 1.1761467456817627, 'loss_margin': 1.2141618728637695, 'loss_contour': 0.16105639934539795, 'loss': 2.5513648986816406}
11008 eval {'eval_loss_recons': np.float64(1.2223591467554267), 'eval_contour_violation_ratio': np.float64(0.06235990574170929)}
11136 train {'loss_recons': 1.073381781578064, 'loss_margin': 1.9454796314239502, 'loss_contour': 0.13710853457450867, 'loss': 3.1559698581695557}
11136 eval {'eval_loss_recons': np.float64(1.0805975089587243), 'eval_contour_violation_ratio': np.float64(0.06563595608943043)}
11264 train {'loss_recons': 1.3080370426177979, 'loss_margin': 0.3952547609806061, 'loss_contour': 0.1470269411802292, 'loss': 1.8503186702728271}
11264 eval {'eval_loss_recons': np.float64(1.2300101584171959), 'eval_contour_violation_ratio': np.float64(0.05862405885395713)}
11392 train {'loss_recons': 1.4518460035324097, 'loss_margin': 2.1876578330993652, 'loss_contour': 0.18455389142036438, 'loss': 3.8240578174591064}
11392 eval {'eval_loss_recons': np.float64(1.3841446257480183), 'eval_contour_violation_ratio': np.float64(0.06919937927467096)}
11520 train {'loss_recons': 1.1147150993347168, 'loss_margin': 1.0567655563354492, 'loss_contour': 0.13926257193088531, 'loss': 2.3107433319091797}
11520 eval {'eval_loss_recons': np.float64(1.0587525254304107), 'eval_contour_violation_ratio': np.float64(0.06138283809414334)}
11648 train {'loss_recons': 1.1628446578979492, 'loss_margin': 1.375784158706665, 'loss_contour': 0.17198455333709717, 'loss': 2.710613250732422}
11648 eval {'eval_loss_recons': np.float64(1.0347983852941476), 'eval_contour_violation_ratio': np.float64(0.0609230415541123)}
11776 train {'loss_recons': 1.2945947647094727, 'loss_margin': 0.9587796330451965, 'loss_contour': 0.18775923550128937, 'loss': 2.441133499145508}
11776 eval {'eval_loss_recons': np.float64(1.0500954689422324), 'eval_contour_violation_ratio': np.float64(0.06126788895913558)}
11904 train {'loss_recons': 1.3315140008926392, 'loss_margin': 3.9897990226745605, 'loss_contour': 0.19373579323291779, 'loss': 5.515048503875732}
11904 eval {'eval_loss_recons': np.float64(1.3643479088761497), 'eval_contour_violation_ratio': np.float64(0.06270475314673257)}
12032 train {'loss_recons': 1.4372179508209229, 'loss_margin': 1.6679103374481201, 'loss_contour': 0.16903156042099, 'loss': 3.2741599082946777}
12032 eval {'eval_loss_recons': np.float64(1.2452905334610438), 'eval_contour_violation_ratio': np.float64(0.06333697338927524)}
12160 train {'loss_recons': 1.2868154048919678, 'loss_margin': 3.1132874488830566, 'loss_contour': 0.20810748636722565, 'loss': 4.60821008682251}
12160 eval {'eval_loss_recons': np.float64(1.11500995193292), 'eval_contour_violation_ratio': np.float64(0.062130007471693775)}
12288 train {'loss_recons': 1.4165492057800293, 'loss_margin': 1.2708696126937866, 'loss_contour': 0.2290189564228058, 'loss': 2.916437864303589}
12288 eval {'eval_loss_recons': np.float64(1.5888524230095442), 'eval_contour_violation_ratio': np.float64(0.0659808034944537)}
12416 train {'loss_recons': 1.1622860431671143, 'loss_margin': 0.9161436557769775, 'loss_contour': 0.20529621839523315, 'loss': 2.2837259769439697}
12416 eval {'eval_loss_recons': np.float64(1.0634439470786874), 'eval_contour_violation_ratio': np.float64(0.06408414276682568)}
12544 train {'loss_recons': 1.314828634262085, 'loss_margin': 0.4607471227645874, 'loss_contour': 0.19939455389976501, 'loss': 1.9749703407287598}
12544 eval {'eval_loss_recons': np.float64(1.3346671599378115), 'eval_contour_violation_ratio': np.float64(0.0640266681993218)}
12672 train {'loss_recons': 1.101676344871521, 'loss_margin': 1.276550531387329, 'loss_contour': 0.19595348834991455, 'loss': 2.5741801261901855}
12672 eval {'eval_loss_recons': np.float64(1.0323276060818607), 'eval_contour_violation_ratio': np.float64(0.06132536352663946)}
12800 train {'loss_recons': 1.3100910186767578, 'loss_margin': 2.3412652015686035, 'loss_contour': 0.1703515499830246, 'loss': 3.8217077255249023}
12800 eval {'eval_loss_recons': np.float64(1.1904707727647883), 'eval_contour_violation_ratio': np.float64(0.058739007988964885)}
12928 train {'loss_recons': 1.1050410270690918, 'loss_margin': 2.3777832984924316, 'loss_contour': 0.17427898943424225, 'loss': 3.6571033000946045}
12928 eval {'eval_loss_recons': np.float64(1.0226977344170676), 'eval_contour_violation_ratio': np.float64(0.05948617736651532)}
13056 train {'loss_recons': 1.2292506694793701, 'loss_margin': 1.7295105457305908, 'loss_contour': 0.1496221274137497, 'loss': 3.1083834171295166}
13056 eval {'eval_loss_recons': np.float64(1.0358264409167093), 'eval_contour_violation_ratio': np.float64(0.05960112650152308)}
13184 train {'loss_recons': 1.3063608407974243, 'loss_margin': 0.8358551263809204, 'loss_contour': 0.19795887172222137, 'loss': 2.340174913406372}
13184 eval {'eval_loss_recons': np.float64(1.177675417003249), 'eval_contour_violation_ratio': np.float64(0.06667049830450025)}
13312 train {'loss_recons': 1.218508243560791, 'loss_margin': 1.5290470123291016, 'loss_contour': 0.196940079331398, 'loss': 2.944495439529419}
13312 eval {'eval_loss_recons': np.float64(1.0679617296465855), 'eval_contour_violation_ratio': np.float64(0.06046324501408127)}
13440 train {'loss_recons': 1.1958668231964111, 'loss_margin': 0.39127254486083984, 'loss_contour': 0.18946678936481476, 'loss': 1.7766062021255493}
13440 eval {'eval_loss_recons': np.float64(1.069523298283563), 'eval_contour_violation_ratio': np.float64(0.061095465256623946)}
13568 train {'loss_recons': 1.3553422689437866, 'loss_margin': 1.2734562158584595, 'loss_contour': 0.20031189918518066, 'loss': 2.8291103839874268}
13568 eval {'eval_loss_recons': np.float64(1.2638915872964396), 'eval_contour_violation_ratio': np.float64(0.06517615954939938)}
13696 train {'loss_recons': 1.1772730350494385, 'loss_margin': 1.9443602561950684, 'loss_contour': 0.17804129421710968, 'loss': 3.2996745109558105}
13696 eval {'eval_loss_recons': np.float64(1.097918842395302), 'eval_contour_violation_ratio': np.float64(0.0609230415541123)}
13824 train {'loss_recons': 1.1784807443618774, 'loss_margin': 2.1772003173828125, 'loss_contour': 0.18641012907028198, 'loss': 3.542091131210327}
13824 eval {'eval_loss_recons': np.float64(1.0222115124980273), 'eval_contour_violation_ratio': np.float64(0.06241738030921317)}
13952 train {'loss_recons': 1.2697362899780273, 'loss_margin': 1.89699125289917, 'loss_contour': 0.2674323320388794, 'loss': 3.434159755706787}
13952 eval {'eval_loss_recons': np.float64(1.1130700852512891), 'eval_contour_violation_ratio': np.float64(0.062072532904189896)}
14080 train {'loss_recons': 1.2033973932266235, 'loss_margin': 2.353849172592163, 'loss_contour': 0.17383794486522675, 'loss': 3.7310845851898193}
14080 eval {'eval_loss_recons': np.float64(1.068772615851277), 'eval_contour_violation_ratio': np.float64(0.06184263463417438)}
14208 train {'loss_recons': 1.7634483575820923, 'loss_margin': 3.322148084640503, 'loss_contour': 0.1994449347257614, 'loss': 5.285041332244873}
14208 eval {'eval_loss_recons': np.float64(1.8002721051078794), 'eval_contour_violation_ratio': np.float64(0.0679349387895856)}
14336 train {'loss_recons': 1.3343085050582886, 'loss_margin': 1.3748563528060913, 'loss_contour': 0.3198881149291992, 'loss': 3.029052972793579}
14336 eval {'eval_loss_recons': np.float64(1.0150405961525606), 'eval_contour_violation_ratio': np.float64(0.05816426231392609)}
14464 train {'loss_recons': 1.0953130722045898, 'loss_margin': 1.0906471014022827, 'loss_contour': 0.17875391244888306, 'loss': 2.3647141456604004}
14464 eval {'eval_loss_recons': np.float64(1.0948127064786777), 'eval_contour_violation_ratio': np.float64(0.05942870279901144)}
14592 train {'loss_recons': 1.079927921295166, 'loss_margin': 0.8936275243759155, 'loss_contour': 0.14589910209178925, 'loss': 2.1194546222686768}
14592 eval {'eval_loss_recons': np.float64(1.1100115905741883), 'eval_contour_violation_ratio': np.float64(0.061095465256623946)}
14720 train {'loss_recons': 1.3730583190917969, 'loss_margin': 0.7135168313980103, 'loss_contour': 0.18806034326553345, 'loss': 2.2746353149414062}
14720 eval {'eval_loss_recons': np.float64(1.2374876923803886), 'eval_contour_violation_ratio': np.float64(0.06770504051957009)}
14848 train {'loss_recons': 1.1077237129211426, 'loss_margin': 1.0644716024398804, 'loss_contour': 0.13876883685588837, 'loss': 2.310964345932007}
14848 eval {'eval_loss_recons': np.float64(1.0949543620413849), 'eval_contour_violation_ratio': np.float64(0.06201505833668602)}
14976 train {'loss_recons': 1.2438335418701172, 'loss_margin': 1.972895860671997, 'loss_contour': 0.23163998126983643, 'loss': 3.4483695030212402}
14976 eval {'eval_loss_recons': np.float64(1.3137814956296567), 'eval_contour_violation_ratio': np.float64(0.0659808034944537)}
15104 train {'loss_recons': 1.3862242698669434, 'loss_margin': 2.1694421768188477, 'loss_contour': 0.1965187042951584, 'loss': 3.752185106277466}
15104 eval {'eval_loss_recons': np.float64(1.19012457384077), 'eval_contour_violation_ratio': np.float64(0.06408414276682568)}
15232 train {'loss_recons': 1.2877581119537354, 'loss_margin': 2.241480588912964, 'loss_contour': 0.20525984466075897, 'loss': 3.7344985008239746}
15232 eval {'eval_loss_recons': np.float64(1.2698977580514947), 'eval_contour_violation_ratio': np.float64(0.06253232944422094)}
15360 train {'loss_recons': 1.166126012802124, 'loss_margin': 2.0545248985290527, 'loss_contour': 0.22963333129882812, 'loss': 3.450284242630005}
15360 eval {'eval_loss_recons': np.float64(1.2056938207994663), 'eval_contour_violation_ratio': np.float64(0.05983102477153859)}
15488 train {'loss_recons': 1.1294734477996826, 'loss_margin': 1.5724643468856812, 'loss_contour': 0.1919073909521103, 'loss': 2.8938450813293457}
15488 eval {'eval_loss_recons': np.float64(1.088832161195974), 'eval_contour_violation_ratio': np.float64(0.06241738030921317)}
15616 train {'loss_recons': 1.240647315979004, 'loss_margin': 0.8361570239067078, 'loss_contour': 0.29818400740623474, 'loss': 2.374988317489624}
15616 eval {'eval_loss_recons': np.float64(1.0646979902942333), 'eval_contour_violation_ratio': np.float64(0.05948617736651532)}
15744 train {'loss_recons': 1.451066493988037, 'loss_margin': 1.3525240421295166, 'loss_contour': 0.16578733921051025, 'loss': 2.9693779945373535}
15744 eval {'eval_loss_recons': np.float64(1.2347948203661248), 'eval_contour_violation_ratio': np.float64(0.0668429220070119)}
15872 train {'loss_recons': 1.24824059009552, 'loss_margin': 0.9121982455253601, 'loss_contour': 0.16894729435443878, 'loss': 2.329385995864868}
15872 eval {'eval_loss_recons': np.float64(1.154462764391431), 'eval_contour_violation_ratio': np.float64(0.07259037875739985)}
16000 train {'loss_recons': 1.2896029949188232, 'loss_margin': 2.7943220138549805, 'loss_contour': 0.2519686818122864, 'loss': 4.3358941078186035}
16000 eval {'eval_loss_recons': np.float64(1.0952872713491573), 'eval_contour_violation_ratio': np.float64(0.06477383757687223)}
16128 train {'loss_recons': 1.1958979368209839, 'loss_margin': 1.202786922454834, 'loss_contour': 0.17444609105587006, 'loss': 2.5731310844421387}
16128 eval {'eval_loss_recons': np.float64(1.0404767261675294), 'eval_contour_violation_ratio': np.float64(0.061095465256623946)}
16256 train {'loss_recons': 1.5043199062347412, 'loss_margin': 2.3484973907470703, 'loss_contour': 0.1789686530828476, 'loss': 4.03178596496582}
16256 eval {'eval_loss_recons': np.float64(1.2618010542228015), 'eval_contour_violation_ratio': np.float64(0.06437151560434508)}
16384 train {'loss_recons': 1.1160519123077393, 'loss_margin': 1.7635185718536377, 'loss_contour': 0.1518888920545578, 'loss': 3.031459331512451}
16384 eval {'eval_loss_recons': np.float64(1.0341447875614067), 'eval_contour_violation_ratio': np.float64(0.06029082131156963)}
16512 train {'loss_recons': 1.0907926559448242, 'loss_margin': 1.6139506101608276, 'loss_contour': 0.22422447800636292, 'loss': 2.9289679527282715}
16512 eval {'eval_loss_recons': np.float64(1.037352278871958), 'eval_contour_violation_ratio': np.float64(0.060808092419104544)}
16640 train {'loss_recons': 1.2648897171020508, 'loss_margin': 0.7975267171859741, 'loss_contour': 0.17555133998394012, 'loss': 2.2379679679870605}
16640 eval {'eval_loss_recons': np.float64(1.3643842255730645), 'eval_contour_violation_ratio': np.float64(0.061152939824127825)}
16768 train {'loss_recons': 1.128448724746704, 'loss_margin': 1.2638787031173706, 'loss_contour': 0.16111800074577332, 'loss': 2.553445339202881}
16768 eval {'eval_loss_recons': np.float64(1.0522234324972572), 'eval_contour_violation_ratio': np.float64(0.06896948100465544)}
16896 train {'loss_recons': 1.0614567995071411, 'loss_margin': 0.7710686922073364, 'loss_contour': 0.1799621880054474, 'loss': 2.0124876499176025}
16896 eval {'eval_loss_recons': np.float64(1.0427880975784034), 'eval_contour_violation_ratio': np.float64(0.06086556698660842)}
17024 train {'loss_recons': 1.209606409072876, 'loss_margin': 2.97586727142334, 'loss_contour': 0.20657940208911896, 'loss': 4.39205265045166}
17024 eval {'eval_loss_recons': np.float64(1.2267695577737847), 'eval_contour_violation_ratio': np.float64(0.0614977872291511)}
17152 train {'loss_recons': 1.1898858547210693, 'loss_margin': 2.418959140777588, 'loss_contour': 0.1749677062034607, 'loss': 3.7838127613067627}
17152 eval {'eval_loss_recons': np.float64(1.0345244172584083), 'eval_contour_violation_ratio': np.float64(0.059716075636530835)}
17280 train {'loss_recons': 1.0688166618347168, 'loss_margin': 2.931882381439209, 'loss_contour': 0.1635853499174118, 'loss': 4.1642842292785645}
17280 eval {'eval_loss_recons': np.float64(1.0572011502562464), 'eval_contour_violation_ratio': np.float64(0.05983102477153859)}
17408 train {'loss_recons': 1.1352788209915161, 'loss_margin': 2.995659828186035, 'loss_contour': 0.15246126055717468, 'loss': 4.28339958190918}
17408 eval {'eval_loss_recons': np.float64(0.9992530782580423), 'eval_contour_violation_ratio': np.float64(0.060808092419104544)}
17536 train {'loss_recons': 1.0584660768508911, 'loss_margin': 0.7338616847991943, 'loss_contour': 0.2414921522140503, 'loss': 2.0338199138641357}
17536 eval {'eval_loss_recons': np.float64(1.0444704940458454), 'eval_contour_violation_ratio': np.float64(0.05431346629116616)}
17664 train {'loss_recons': 1.2552227973937988, 'loss_margin': 2.360670566558838, 'loss_contour': 0.14336775243282318, 'loss': 3.759261131286621}
17664 eval {'eval_loss_recons': np.float64(1.1427821966842409), 'eval_contour_violation_ratio': np.float64(0.05891143169147652)}
17792 train {'loss_recons': 1.0680075883865356, 'loss_margin': 1.2486828565597534, 'loss_contour': 0.16886842250823975, 'loss': 2.4855589866638184}
17792 eval {'eval_loss_recons': np.float64(1.0693804630011245), 'eval_contour_violation_ratio': np.float64(0.05937122823150756)}
17920 train {'loss_recons': 1.1822257041931152, 'loss_margin': 1.4551283121109009, 'loss_contour': 0.1925782859325409, 'loss': 2.82993221282959}
17920 eval {'eval_loss_recons': np.float64(1.0576771089265864), 'eval_contour_violation_ratio': np.float64(0.0659808034944537)}
18048 train {'loss_recons': 1.3066662549972534, 'loss_margin': 4.140532493591309, 'loss_contour': 0.12308994680643082, 'loss': 5.57028865814209}
18048 eval {'eval_loss_recons': np.float64(1.2427957700816132), 'eval_contour_violation_ratio': np.float64(0.053968618886142884)}
18176 train {'loss_recons': 1.2281312942504883, 'loss_margin': 1.7544808387756348, 'loss_contour': 0.20867133140563965, 'loss': 3.1912834644317627}
18176 eval {'eval_loss_recons': np.float64(1.0289948323724984), 'eval_contour_violation_ratio': np.float64(0.054888211966204956)}
18304 train {'loss_recons': 1.0480371713638306, 'loss_margin': 2.1067020893096924, 'loss_contour': 0.15724660456180573, 'loss': 3.311985969543457}
18304 eval {'eval_loss_recons': np.float64(0.9481449933430727), 'eval_contour_violation_ratio': np.float64(0.05126731421346054)}
18432 train {'loss_recons': 1.1413300037384033, 'loss_margin': 1.3033621311187744, 'loss_contour': 0.2001747041940689, 'loss': 2.644866943359375}
18432 eval {'eval_loss_recons': np.float64(1.1432609353176264), 'eval_contour_violation_ratio': np.float64(0.06471636300936835)}
18560 train {'loss_recons': 1.2483301162719727, 'loss_margin': 1.537362813949585, 'loss_contour': 0.18205134570598602, 'loss': 2.9677443504333496}
18560 eval {'eval_loss_recons': np.float64(0.9641280701106931), 'eval_contour_violation_ratio': np.float64(0.0561526524512903)}
18688 train {'loss_recons': 0.9985103011131287, 'loss_margin': 1.3036048412322998, 'loss_contour': 0.16506922245025635, 'loss': 2.467184543609619}
18688 eval {'eval_loss_recons': np.float64(0.9456742299651816), 'eval_contour_violation_ratio': np.float64(0.05580780504626703)}
18816 train {'loss_recons': 1.2328203916549683, 'loss_margin': 3.7043888568878174, 'loss_contour': 0.15437154471874237, 'loss': 5.091580867767334}
18816 eval {'eval_loss_recons': np.float64(1.13814137553045), 'eval_contour_violation_ratio': np.float64(0.05666992355882522)}
18944 train {'loss_recons': 0.9478910565376282, 'loss_margin': 0.7521976232528687, 'loss_contour': 0.35166552662849426, 'loss': 2.0517542362213135}
18944 eval {'eval_loss_recons': np.float64(0.9381596160993952), 'eval_contour_violation_ratio': np.float64(0.053968618886142884)}
19072 train {'loss_recons': 1.0651715993881226, 'loss_margin': 2.1863083839416504, 'loss_contour': 0.20236565172672272, 'loss': 3.453845739364624}
19072 eval {'eval_loss_recons': np.float64(0.9571219480434606), 'eval_contour_violation_ratio': np.float64(0.05368124604862348)}
19200 train {'loss_recons': 1.1096349954605103, 'loss_margin': 1.3484970331192017, 'loss_contour': 0.12462984770536423, 'loss': 2.582761764526367}
19200 eval {'eval_loss_recons': np.float64(0.9565913317277671), 'eval_contour_violation_ratio': np.float64(0.048221162135754926)}
19328 train {'loss_recons': 1.1673405170440674, 'loss_margin': 1.7965573072433472, 'loss_contour': 0.20761851966381073, 'loss': 3.171516180038452}
19328 eval {'eval_loss_recons': np.float64(1.1683904980016144), 'eval_contour_violation_ratio': np.float64(0.051554687050979944)}
19456 train {'loss_recons': 1.0695641040802002, 'loss_margin': 1.408764123916626, 'loss_contour': 0.17338451743125916, 'loss': 2.651712656021118}
19456 eval {'eval_loss_recons': np.float64(0.9186795327653531), 'eval_contour_violation_ratio': np.float64(0.051037415943445026)}
19584 train {'loss_recons': 1.1209959983825684, 'loss_margin': 1.5210797786712646, 'loss_contour': 0.21207274496555328, 'loss': 2.8541486263275146}
19584 eval {'eval_loss_recons': np.float64(0.9738634659638492), 'eval_contour_violation_ratio': np.float64(0.055922754181274785)}
19712 train {'loss_recons': 1.0249651670455933, 'loss_margin': 1.453291654586792, 'loss_contour': 0.1771901547908783, 'loss': 2.655446767807007}
19712 eval {'eval_loss_recons': np.float64(0.9314896200342108), 'eval_contour_violation_ratio': np.float64(0.04879590781079372)}
19840 train {'loss_recons': 0.9233195781707764, 'loss_margin': 1.921335220336914, 'loss_contour': 0.15996649861335754, 'loss': 3.0046212673187256}
19840 eval {'eval_loss_recons': np.float64(0.9111688845547895), 'eval_contour_violation_ratio': np.float64(0.0513822633484683)}
19968 train {'loss_recons': 1.1448489427566528, 'loss_margin': 2.637964963912964, 'loss_contour': 0.15409934520721436, 'loss': 3.93691349029541}
19968 eval {'eval_loss_recons': np.float64(0.909360897193882), 'eval_contour_violation_ratio': np.float64(0.047301569055692855)}
20096 train {'loss_recons': 1.113670825958252, 'loss_margin': 2.684937000274658, 'loss_contour': 0.1242179349064827, 'loss': 3.922825813293457}
20096 eval {'eval_loss_recons': np.float64(0.9875125905165761), 'eval_contour_violation_ratio': np.float64(0.05402609345364676)}
20224 train {'loss_recons': 0.8580516576766968, 'loss_margin': 0.9636136293411255, 'loss_contour': 0.1354626715183258, 'loss': 1.9571279287338257}
20224 eval {'eval_loss_recons': np.float64(0.8884581229507151), 'eval_contour_violation_ratio': np.float64(0.0488533823782976)}
20352 train {'loss_recons': 0.9526466131210327, 'loss_margin': 2.0847558975219727, 'loss_contour': 0.14254319667816162, 'loss': 3.179945945739746}
20352 eval {'eval_loss_recons': np.float64(0.9248270554335051), 'eval_contour_violation_ratio': np.float64(0.04753146732570837)}
20480 train {'loss_recons': 1.0545711517333984, 'loss_margin': 1.0272698402404785, 'loss_contour': 0.1690894365310669, 'loss': 2.2509303092956543}
20480 eval {'eval_loss_recons': np.float64(0.8940477470697812), 'eval_contour_violation_ratio': np.float64(0.0516696361859877)}
20608 train {'loss_recons': 1.0523295402526855, 'loss_margin': 1.8518600463867188, 'loss_contour': 0.17776507139205933, 'loss': 3.0819547176361084}
20608 eval {'eval_loss_recons': np.float64(0.907674411594074), 'eval_contour_violation_ratio': np.float64(0.04569228116558423)}
20736 train {'loss_recons': 1.0481078624725342, 'loss_margin': 1.4203320741653442, 'loss_contour': 0.1920444667339325, 'loss': 2.660484552383423}
20736 eval {'eval_loss_recons': np.float64(0.9053307055603173), 'eval_contour_violation_ratio': np.float64(0.054485889993677795)}
20864 train {'loss_recons': 1.0131847858428955, 'loss_margin': 1.047171711921692, 'loss_contour': 0.1702389121055603, 'loss': 2.230595588684082}
20864 eval {'eval_loss_recons': np.float64(0.9768269234381525), 'eval_contour_violation_ratio': np.float64(0.04896833151330536)}
20992 train {'loss_recons': 0.9800756573677063, 'loss_margin': 2.223130702972412, 'loss_contour': 0.18561121821403503, 'loss': 3.388817548751831}
20992 eval {'eval_loss_recons': np.float64(0.8650499649330367), 'eval_contour_violation_ratio': np.float64(0.050060348295879076)}
21120 train {'loss_recons': 0.9484667778015137, 'loss_margin': 2.5996792316436768, 'loss_contour': 0.16491293907165527, 'loss': 3.7130589485168457}
21120 eval {'eval_loss_recons': np.float64(0.8935756341056698), 'eval_contour_violation_ratio': np.float64(0.045347433760560954)}
21248 train {'loss_recons': 0.9984580278396606, 'loss_margin': 1.3437409400939941, 'loss_contour': 0.2332409769296646, 'loss': 2.575439929962158}
21248 eval {'eval_loss_recons': np.float64(0.9173213087240477), 'eval_contour_violation_ratio': np.float64(0.05172711075349158)}
21376 train {'loss_recons': 0.9881677627563477, 'loss_margin': 2.2503557205200195, 'loss_contour': 0.18987976014614105, 'loss': 3.42840313911438}
21376 eval {'eval_loss_recons': np.float64(0.859633269024899), 'eval_contour_violation_ratio': np.float64(0.046267026840623025)}
21504 train {'loss_recons': 1.010110855102539, 'loss_margin': 1.8604363203048706, 'loss_contour': 0.1301397830247879, 'loss': 3.0006871223449707}
21504 eval {'eval_loss_recons': np.float64(0.8578853429671985), 'eval_contour_violation_ratio': np.float64(0.04804873843324329)}
21632 train {'loss_recons': 1.0607142448425293, 'loss_margin': 2.7213826179504395, 'loss_contour': 0.13300944864749908, 'loss': 3.9151062965393066}
21632 eval {'eval_loss_recons': np.float64(0.9173047704859866), 'eval_contour_violation_ratio': np.float64(0.046324501408126904)}
21760 train {'loss_recons': 1.0665831565856934, 'loss_margin': 2.8873515129089355, 'loss_contour': 0.1427382230758667, 'loss': 4.096673011779785}
21760 eval {'eval_loss_recons': np.float64(0.8926367951116446), 'eval_contour_violation_ratio': np.float64(0.04919822978332088)}
21888 train {'loss_recons': 1.2314486503601074, 'loss_margin': 5.169437408447266, 'loss_contour': 0.25133395195007324, 'loss': 6.652219772338867}
21888 eval {'eval_loss_recons': np.float64(1.2028084997354114), 'eval_contour_violation_ratio': np.float64(0.04695672165066958)}
22016 train {'loss_recons': 0.9550387263298035, 'loss_margin': 1.962482213973999, 'loss_contour': 0.2741760015487671, 'loss': 3.191697120666504}
22016 eval {'eval_loss_recons': np.float64(0.8819957856453419), 'eval_contour_violation_ratio': np.float64(0.0468992470831657)}
22144 train {'loss_recons': 0.8779391050338745, 'loss_margin': 2.0255985260009766, 'loss_contour': 0.12801603972911835, 'loss': 3.0315537452697754}
22144 eval {'eval_loss_recons': np.float64(0.8487115766245312), 'eval_contour_violation_ratio': np.float64(0.04563480659808035)}
22272 train {'loss_recons': 1.1179845333099365, 'loss_margin': 2.83488392829895, 'loss_contour': 0.181535005569458, 'loss': 4.134403228759766}
22272 eval {'eval_loss_recons': np.float64(0.865578311701759), 'eval_contour_violation_ratio': np.float64(0.049140755215817004)}
22400 train {'loss_recons': 0.9258683919906616, 'loss_margin': 1.4962551593780518, 'loss_contour': 0.19762293994426727, 'loss': 2.619746446609497}
22400 eval {'eval_loss_recons': np.float64(0.8871425369486339), 'eval_contour_violation_ratio': np.float64(0.04781884016322777)}
22528 train {'loss_recons': 0.9598208665847778, 'loss_margin': 2.9680404663085938, 'loss_contour': 0.307838499546051, 'loss': 4.235699653625488}
22528 eval {'eval_loss_recons': np.float64(0.9135988506578363), 'eval_contour_violation_ratio': np.float64(0.04574975573308811)}
22656 train {'loss_recons': 0.9436209797859192, 'loss_margin': 1.2279549837112427, 'loss_contour': 0.17359095811843872, 'loss': 2.3451669216156006}
22656 eval {'eval_loss_recons': np.float64(0.8242815984470385), 'eval_contour_violation_ratio': np.float64(0.04615207770561527)}
22784 train {'loss_recons': 1.0637000799179077, 'loss_margin': 3.681978225708008, 'loss_contour': 0.1569857895374298, 'loss': 4.9026641845703125}
22784 eval {'eval_loss_recons': np.float64(0.9019813656215947), 'eval_contour_violation_ratio': np.float64(0.04477268808552216)}
22912 train {'loss_recons': 1.143710970878601, 'loss_margin': 3.514219045639038, 'loss_contour': 0.15177400410175323, 'loss': 4.809703826904297}
22912 eval {'eval_loss_recons': np.float64(0.9052485810871275), 'eval_contour_violation_ratio': np.float64(0.04781884016322777)}
23040 train {'loss_recons': 1.1885077953338623, 'loss_margin': 3.840878486633301, 'loss_contour': 0.21370401978492737, 'loss': 5.243090629577637}
23040 eval {'eval_loss_recons': np.float64(0.9773015006738298), 'eval_contour_violation_ratio': np.float64(0.04258865452037473)}
23168 train {'loss_recons': 0.9450966119766235, 'loss_margin': 1.6974406242370605, 'loss_contour': 0.14452821016311646, 'loss': 2.787065267562866}
23168 eval {'eval_loss_recons': np.float64(0.8002398096046871), 'eval_contour_violation_ratio': np.float64(0.04454278981550664)}
23296 train {'loss_recons': 1.0097821950912476, 'loss_margin': 2.997694730758667, 'loss_contour': 0.15793676674365997, 'loss': 4.1654133796691895}
23296 eval {'eval_loss_recons': np.float64(0.8223074692032507), 'eval_contour_violation_ratio': np.float64(0.04753146732570837)}
23424 train {'loss_recons': 0.9450689554214478, 'loss_margin': 1.3263399600982666, 'loss_contour': 0.15651839971542358, 'loss': 2.4279274940490723}
23424 eval {'eval_loss_recons': np.float64(0.8215691146530453), 'eval_contour_violation_ratio': np.float64(0.047301569055692855)}
23552 train {'loss_recons': 0.9542986750602722, 'loss_margin': 3.899844169616699, 'loss_contour': 0.19550521671772003, 'loss': 5.049647808074951}
23552 eval {'eval_loss_recons': np.float64(0.8226280408399647), 'eval_contour_violation_ratio': np.float64(0.04161158687280878)}
23680 train {'loss_recons': 0.9341946840286255, 'loss_margin': 1.3607354164123535, 'loss_contour': 0.11952240765094757, 'loss': 2.414452314376831}
23680 eval {'eval_loss_recons': np.float64(0.7823443299411782), 'eval_contour_violation_ratio': np.float64(0.04431289154549112)}
23808 train {'loss_recons': 0.9493896961212158, 'loss_margin': 3.229055881500244, 'loss_contour': 0.17175544798374176, 'loss': 4.35020112991333}
23808 eval {'eval_loss_recons': np.float64(0.8328937260645618), 'eval_contour_violation_ratio': np.float64(0.04327834933042129)}
23936 train {'loss_recons': 0.7845777273178101, 'loss_margin': 1.0845921039581299, 'loss_contour': 0.14389272034168243, 'loss': 2.0130624771118164}
23936 eval {'eval_loss_recons': np.float64(0.8109528490481983), 'eval_contour_violation_ratio': np.float64(0.045232484625553196)}
24064 train {'loss_recons': 0.8414525389671326, 'loss_margin': 2.768465042114258, 'loss_contour': 0.20255620777606964, 'loss': 3.812473773956299}
24064 eval {'eval_loss_recons': np.float64(0.7875560337986056), 'eval_contour_violation_ratio': np.float64(0.04333582389792517)}
24192 train {'loss_recons': 1.1660029888153076, 'loss_margin': 2.130162239074707, 'loss_contour': 0.12868547439575195, 'loss': 3.4248507022857666}
24192 eval {'eval_loss_recons': np.float64(0.8157183887486059), 'eval_contour_violation_ratio': np.float64(0.0468992470831657)}
24320 train {'loss_recons': 0.9100281596183777, 'loss_margin': 2.550835609436035, 'loss_contour': 0.18324069678783417, 'loss': 3.644104480743408}
24320 eval {'eval_loss_recons': np.float64(0.7916396028237089), 'eval_contour_violation_ratio': np.float64(0.04281855279039025)}
24448 train {'loss_recons': 0.8838369250297546, 'loss_margin': 2.0397541522979736, 'loss_contour': 0.13840006291866302, 'loss': 3.0619912147521973}
24448 eval {'eval_loss_recons': np.float64(0.7725911552434231), 'eval_contour_violation_ratio': np.float64(0.04166906144031266)}
24576 train {'loss_recons': 0.9986376166343689, 'loss_margin': 2.67979097366333, 'loss_contour': 0.18206705152988434, 'loss': 3.8604958057403564}
24576 eval {'eval_loss_recons': np.float64(0.8088848013990605), 'eval_contour_violation_ratio': np.float64(0.04086441749525835)}
24704 train {'loss_recons': 0.9958440065383911, 'loss_margin': 0.6559083461761475, 'loss_contour': 0.13893964886665344, 'loss': 1.7906919717788696}
24704 eval {'eval_loss_recons': np.float64(0.8397706909222872), 'eval_contour_violation_ratio': np.float64(0.05034772113339847)}
24832 train {'loss_recons': 0.9043521881103516, 'loss_margin': 1.9099814891815186, 'loss_contour': 0.1526445597410202, 'loss': 2.9669783115386963}
24832 eval {'eval_loss_recons': np.float64(0.7980727645844505), 'eval_contour_violation_ratio': np.float64(0.042013908845335936)}
24960 train {'loss_recons': 0.9567673206329346, 'loss_margin': 1.90824556350708, 'loss_contour': 0.1556771993637085, 'loss': 3.0206899642944336}
24960 eval {'eval_loss_recons': np.float64(0.808339263475139), 'eval_contour_violation_ratio': np.float64(0.041554112305304904)}
25088 train {'loss_recons': 0.8595598340034485, 'loss_margin': 1.6130316257476807, 'loss_contour': 0.12249446660280228, 'loss': 2.595085859298706}
25088 eval {'eval_loss_recons': np.float64(0.7690997096659135), 'eval_contour_violation_ratio': np.float64(0.04120926490028162)}
25216 train {'loss_recons': 1.0502488613128662, 'loss_margin': 2.498896598815918, 'loss_contour': 0.15951356291770935, 'loss': 3.7086589336395264}
25216 eval {'eval_loss_recons': np.float64(0.8265263188377968), 'eval_contour_violation_ratio': np.float64(0.04086441749525835)}
25344 train {'loss_recons': 0.9705764055252075, 'loss_margin': 1.8724970817565918, 'loss_contour': 0.13302692770957947, 'loss': 2.976100206375122}
25344 eval {'eval_loss_recons': np.float64(0.7899697931431553), 'eval_contour_violation_ratio': np.float64(0.04574975573308811)}
25472 train {'loss_recons': 0.8534334897994995, 'loss_margin': 1.1074998378753662, 'loss_contour': 0.17863275110721588, 'loss': 2.13956618309021}
25472 eval {'eval_loss_recons': np.float64(0.8223503135722751), 'eval_contour_violation_ratio': np.float64(0.04264612908787861)}
25600 train {'loss_recons': 0.9075790047645569, 'loss_margin': 2.013155460357666, 'loss_contour': 0.13531741499900818, 'loss': 3.056051731109619}
25600 eval {'eval_loss_recons': np.float64(0.7609776442701186), 'eval_contour_violation_ratio': np.float64(0.039082705902638085)}
25728 train {'loss_recons': 0.8628547787666321, 'loss_margin': 2.338552474975586, 'loss_contour': 0.30654841661453247, 'loss': 3.507955551147461}
25728 eval {'eval_loss_recons': np.float64(0.8098305165893007), 'eval_contour_violation_ratio': np.float64(0.04414046784297948)}
25856 train {'loss_recons': 0.8319805860519409, 'loss_margin': 1.9668463468551636, 'loss_contour': 0.13111591339111328, 'loss': 2.9299428462982178}
25856 eval {'eval_loss_recons': np.float64(0.7537248640292091), 'eval_contour_violation_ratio': np.float64(0.0412667394677855)}
25984 train {'loss_recons': 1.0140661001205444, 'loss_margin': 3.5638856887817383, 'loss_contour': 0.17878612875938416, 'loss': 4.756738185882568}
25984 eval {'eval_loss_recons': np.float64(0.7747945753536927), 'eval_contour_violation_ratio': np.float64(0.04034714638772343)}
26112 train {'loss_recons': 0.98633873462677, 'loss_margin': 0.7561648488044739, 'loss_contour': 0.1630532294511795, 'loss': 1.9055569171905518}
26112 eval {'eval_loss_recons': np.float64(0.8465386009822901), 'eval_contour_violation_ratio': np.float64(0.047761365595723894)}
26240 train {'loss_recons': 0.8022197484970093, 'loss_margin': 1.728402853012085, 'loss_contour': 0.16002723574638367, 'loss': 2.6906497478485107}
26240 eval {'eval_loss_recons': np.float64(0.7828332563377), 'eval_contour_violation_ratio': np.float64(0.04063451922524283)}
26368 train {'loss_recons': 0.867975115776062, 'loss_margin': 3.3943424224853516, 'loss_contour': 0.16787786781787872, 'loss': 4.430195331573486}
26368 eval {'eval_loss_recons': np.float64(0.7683345841333724), 'eval_contour_violation_ratio': np.float64(0.042473705385366975)}
26496 train {'loss_recons': 0.867403507232666, 'loss_margin': 1.8597921133041382, 'loss_contour': 0.20748451352119446, 'loss': 2.934680223464966}
26496 eval {'eval_loss_recons': np.float64(0.776973804728558), 'eval_contour_violation_ratio': np.float64(0.04270360365538249)}
26624 train {'loss_recons': 0.8160245418548584, 'loss_margin': 2.319610118865967, 'loss_contour': 0.16384920477867126, 'loss': 3.2994837760925293}
26624 eval {'eval_loss_recons': np.float64(0.7426196283154805), 'eval_contour_violation_ratio': np.float64(0.04023219725271567)}
26752 train {'loss_recons': 0.8377318382263184, 'loss_margin': 2.9254889488220215, 'loss_contour': 0.21231400966644287, 'loss': 3.9755349159240723}
26752 eval {'eval_loss_recons': np.float64(0.7564689881713419), 'eval_contour_violation_ratio': np.float64(0.04299097649290189)}
26880 train {'loss_recons': 0.8063833713531494, 'loss_margin': 1.741051197052002, 'loss_contour': 0.16496697068214417, 'loss': 2.7124016284942627}
26880 eval {'eval_loss_recons': np.float64(0.736676636675626), 'eval_contour_violation_ratio': np.float64(0.04086441749525835)}
27008 train {'loss_recons': 1.0866925716400146, 'loss_margin': 2.452474594116211, 'loss_contour': 0.29913103580474854, 'loss': 3.8382983207702637}
27008 eval {'eval_loss_recons': np.float64(0.8457578923739111), 'eval_contour_violation_ratio': np.float64(0.04034714638772343)}
27136 train {'loss_recons': 0.9110718965530396, 'loss_margin': 2.0068132877349854, 'loss_contour': 0.1760464310646057, 'loss': 3.0939316749572754}
27136 eval {'eval_loss_recons': np.float64(0.7471962152497187), 'eval_contour_violation_ratio': np.float64(0.04178401057532042)}
27264 train {'loss_recons': 0.9470855593681335, 'loss_margin': 3.4112143516540527, 'loss_contour': 0.14225925505161285, 'loss': 4.500558853149414}
27264 eval {'eval_loss_recons': np.float64(0.7538728367039733), 'eval_contour_violation_ratio': np.float64(0.04034714638772343)}
27392 train {'loss_recons': 0.8407419323921204, 'loss_margin': 3.6036548614501953, 'loss_contour': 0.16174207627773285, 'loss': 4.606139183044434}
27392 eval {'eval_loss_recons': np.float64(0.7320583200390853), 'eval_contour_violation_ratio': np.float64(0.038622909362607045)}
27520 train {'loss_recons': 0.8539357781410217, 'loss_margin': 1.7102768421173096, 'loss_contour': 0.17870067059993744, 'loss': 2.742913246154785}
27520 eval {'eval_loss_recons': np.float64(0.7575926170546844), 'eval_contour_violation_ratio': np.float64(0.04425541697798724)}
27648 train {'loss_recons': 0.9729718565940857, 'loss_margin': 3.546708822250366, 'loss_contour': 0.1502133011817932, 'loss': 4.669893741607666}
27648 eval {'eval_loss_recons': np.float64(0.7887343466689769), 'eval_contour_violation_ratio': np.float64(0.042876027357894135)}
27776 train {'loss_recons': 0.8715343475341797, 'loss_margin': 3.6260228157043457, 'loss_contour': 0.16824902594089508, 'loss': 4.665806293487549}
27776 eval {'eval_loss_recons': np.float64(0.7133211285760881), 'eval_contour_violation_ratio': np.float64(0.04120926490028162)}
27904 train {'loss_recons': 0.8650168180465698, 'loss_margin': 1.6180516481399536, 'loss_contour': 0.1552726924419403, 'loss': 2.638341188430786}
27904 eval {'eval_loss_recons': np.float64(0.6855506426597399), 'eval_contour_violation_ratio': np.float64(0.03977240071268464)}
28032 train {'loss_recons': 0.9382871389389038, 'loss_margin': 2.9890079498291016, 'loss_contour': 0.21488280594348907, 'loss': 4.142178058624268}
28032 eval {'eval_loss_recons': np.float64(0.771062188290591), 'eval_contour_violation_ratio': np.float64(0.040979366630266106)}
28160 train {'loss_recons': 0.7018354535102844, 'loss_margin': 2.171731948852539, 'loss_contour': 0.14357125759124756, 'loss': 3.0171384811401367}
28160 eval {'eval_loss_recons': np.float64(0.6873656652715264), 'eval_contour_violation_ratio': np.float64(0.04023219725271567)}
28288 train {'loss_recons': 0.8383088707923889, 'loss_margin': 2.287722587585449, 'loss_contour': 0.15934626758098602, 'loss': 3.2853777408599854}
28288 eval {'eval_loss_recons': np.float64(0.707867306873343), 'eval_contour_violation_ratio': np.float64(0.04178401057532042)}
28416 train {'loss_recons': 0.9225248098373413, 'loss_margin': 3.404825210571289, 'loss_contour': 0.15436893701553345, 'loss': 4.481719017028809}
28416 eval {'eval_loss_recons': np.float64(0.8776055441961758), 'eval_contour_violation_ratio': np.float64(0.042473705385366975)}
28544 train {'loss_recons': 0.746706485748291, 'loss_margin': 1.8464640378952026, 'loss_contour': 0.16855674982070923, 'loss': 2.7617273330688477}
28544 eval {'eval_loss_recons': np.float64(0.725128826811331), 'eval_contour_violation_ratio': np.float64(0.04316340019541353)}
28672 train {'loss_recons': 0.8754072785377502, 'loss_margin': 2.1887784004211426, 'loss_contour': 0.137399360537529, 'loss': 3.201585054397583}
28672 eval {'eval_loss_recons': np.float64(0.6851939081121852), 'eval_contour_violation_ratio': np.float64(0.042473705385366975)}
28800 train {'loss_recons': 0.8001954555511475, 'loss_margin': 2.295541286468506, 'loss_contour': 0.18404385447502136, 'loss': 3.279780626296997}
28800 eval {'eval_loss_recons': np.float64(0.6772646047045875), 'eval_contour_violation_ratio': np.float64(0.04115179033277774)}
28928 train {'loss_recons': 0.719992995262146, 'loss_margin': 1.2735748291015625, 'loss_contour': 0.20798686146736145, 'loss': 2.201554775238037}
28928 eval {'eval_loss_recons': np.float64(0.6638826830435823), 'eval_contour_violation_ratio': np.float64(0.040519570090235074)}
29056 train {'loss_recons': 0.9249236583709717, 'loss_margin': 4.8051252365112305, 'loss_contour': 0.20364929735660553, 'loss': 5.933698654174805}
29056 eval {'eval_loss_recons': np.float64(0.7243027728391961), 'eval_contour_violation_ratio': np.float64(0.04299097649290189)}
29184 train {'loss_recons': 0.7469742298126221, 'loss_margin': 1.3648159503936768, 'loss_contour': 0.12271567434072495, 'loss': 2.2345058917999268}
29184 eval {'eval_loss_recons': np.float64(0.6621154193505331), 'eval_contour_violation_ratio': np.float64(0.04161158687280878)}
29312 train {'loss_recons': 0.8333433270454407, 'loss_margin': 3.212400436401367, 'loss_contour': 0.19435977935791016, 'loss': 4.240103721618652}
29312 eval {'eval_loss_recons': np.float64(0.7549421849869188), 'eval_contour_violation_ratio': np.float64(0.04080694292775447)}
29440 train {'loss_recons': 0.7721454501152039, 'loss_margin': 1.693027377128601, 'loss_contour': 0.18032242357730865, 'loss': 2.6454951763153076}
29440 eval {'eval_loss_recons': np.float64(0.6693696048868113), 'eval_contour_violation_ratio': np.float64(0.04166906144031266)}
29568 train {'loss_recons': 0.8078609704971313, 'loss_margin': 3.6507530212402344, 'loss_contour': 0.11335721611976624, 'loss': 4.5719709396362305}
29568 eval {'eval_loss_recons': np.float64(0.6912918649268383), 'eval_contour_violation_ratio': np.float64(0.04034714638772343)}
29696 train {'loss_recons': 0.7675472497940063, 'loss_margin': 1.2494724988937378, 'loss_contour': 0.165517196059227, 'loss': 2.1825368404388428}
29696 eval {'eval_loss_recons': np.float64(0.6650855086033125), 'eval_contour_violation_ratio': np.float64(0.04120926490028162)}
29824 train {'loss_recons': 0.7205759286880493, 'loss_margin': 3.2058300971984863, 'loss_contour': 0.20072238147258759, 'loss': 4.1271281242370605}
29824 eval {'eval_loss_recons': np.float64(0.6728358752982779), 'eval_contour_violation_ratio': np.float64(0.04172653600781654)}
29952 train {'loss_recons': 0.8062818050384521, 'loss_margin': 2.4951300621032715, 'loss_contour': 0.18227751553058624, 'loss': 3.483689308166504}
29952 eval {'eval_loss_recons': np.float64(0.7065587079129203), 'eval_contour_violation_ratio': np.float64(0.03793321455256049)}
30080 train {'loss_recons': 0.8434942364692688, 'loss_margin': 1.8777813911437988, 'loss_contour': 0.14608190953731537, 'loss': 2.8673574924468994}
30080 eval {'eval_loss_recons': np.float64(0.6841767570235954), 'eval_contour_violation_ratio': np.float64(0.04034714638772343)}
30208 train {'loss_recons': 0.9034809470176697, 'loss_margin': 2.8367483615875244, 'loss_contour': 0.15631221234798431, 'loss': 3.8965415954589844}
30208 eval {'eval_loss_recons': np.float64(0.6601498106183964), 'eval_contour_violation_ratio': np.float64(0.03925512960514972)}
30336 train {'loss_recons': 0.7814399003982544, 'loss_margin': 2.6442999839782715, 'loss_contour': 0.1563834846019745, 'loss': 3.582123279571533}
30336 eval {'eval_loss_recons': np.float64(0.7924670509576167), 'eval_contour_violation_ratio': np.float64(0.04086441749525835)}
30464 train {'loss_recons': 0.7850199341773987, 'loss_margin': 3.4512526988983154, 'loss_contour': 0.18674583733081818, 'loss': 4.423018455505371}
30464 eval {'eval_loss_recons': np.float64(0.6611728880999068), 'eval_contour_violation_ratio': np.float64(0.039140180470141964)}
30592 train {'loss_recons': 0.8458652496337891, 'loss_margin': 2.599595785140991, 'loss_contour': 0.1710892617702484, 'loss': 3.6165502071380615}
30592 eval {'eval_loss_recons': np.float64(0.6717561268234705), 'eval_contour_violation_ratio': np.float64(0.03839301109259153)}
30720 train {'loss_recons': 0.7679896354675293, 'loss_margin': 2.055030107498169, 'loss_contour': 0.14362786710262299, 'loss': 2.9666476249694824}
30720 eval {'eval_loss_recons': np.float64(0.7328871520958472), 'eval_contour_violation_ratio': np.float64(0.03879533306511868)}
30848 train {'loss_recons': 0.767282247543335, 'loss_margin': 2.2917423248291016, 'loss_contour': 0.13949567079544067, 'loss': 3.1985201835632324}
30848 eval {'eval_loss_recons': np.float64(0.6640571607962171), 'eval_contour_violation_ratio': np.float64(0.037703316282544974)}
30976 train {'loss_recons': 0.7607184648513794, 'loss_margin': 4.0083465576171875, 'loss_contour': 0.17712818086147308, 'loss': 4.946193218231201}
30976 eval {'eval_loss_recons': np.float64(0.6791423409146913), 'eval_contour_violation_ratio': np.float64(0.04310592562790965)}
31104 train {'loss_recons': 0.8994998931884766, 'loss_margin': 1.7345553636550903, 'loss_contour': 0.13345667719841003, 'loss': 2.7675118446350098}
31104 eval {'eval_loss_recons': np.float64(0.6934326959987238), 'eval_contour_violation_ratio': np.float64(0.04506006092304155)}
31232 train {'loss_recons': 0.7985763549804688, 'loss_margin': 1.57435941696167, 'loss_contour': 0.1433037370443344, 'loss': 2.5162394046783447}
31232 eval {'eval_loss_recons': np.float64(0.6549556120420423), 'eval_contour_violation_ratio': np.float64(0.03954250244266912)}
31360 train {'loss_recons': 0.8633459806442261, 'loss_margin': 3.307953357696533, 'loss_contour': 0.1497737318277359, 'loss': 4.321073055267334}
31360 eval {'eval_loss_recons': np.float64(0.6559037538174419), 'eval_contour_violation_ratio': np.float64(0.04057704465773895)}
31488 train {'loss_recons': 0.7225067019462585, 'loss_margin': 3.8213682174682617, 'loss_contour': 0.31265634298324585, 'loss': 4.856531143188477}
31488 eval {'eval_loss_recons': np.float64(0.6496366331093071), 'eval_contour_violation_ratio': np.float64(0.040002298982700156)}
31616 train {'loss_recons': 0.7775630950927734, 'loss_margin': 1.664896845817566, 'loss_contour': 0.17061688005924225, 'loss': 2.61307692527771}
31616 eval {'eval_loss_recons': np.float64(0.6345552295097415), 'eval_contour_violation_ratio': np.float64(0.04189895971032818)}
31744 train {'loss_recons': 0.757402777671814, 'loss_margin': 2.288754463195801, 'loss_contour': 0.19236353039741516, 'loss': 3.238520860671997}
31744 eval {'eval_loss_recons': np.float64(0.6625090012023637), 'eval_contour_violation_ratio': np.float64(0.04057704465773895)}
31872 train {'loss_recons': 0.7346005439758301, 'loss_margin': 2.4428486824035645, 'loss_contour': 0.15705615282058716, 'loss': 3.334505319595337}
31872 eval {'eval_loss_recons': np.float64(0.6270113222777236), 'eval_contour_violation_ratio': np.float64(0.04057704465773895)}
32000 train {'loss_recons': 0.6126962900161743, 'loss_margin': 2.7061824798583984, 'loss_contour': 0.16537825763225555, 'loss': 3.4842569828033447}
32000 eval {'eval_loss_recons': np.float64(0.6435709611283714), 'eval_contour_violation_ratio': np.float64(0.039082705902638085)}
32128 train {'loss_recons': 0.7420548796653748, 'loss_margin': 3.2245473861694336, 'loss_contour': 0.1735914945602417, 'loss': 4.140193939208984}
32128 eval {'eval_loss_recons': np.float64(0.861138782687345), 'eval_contour_violation_ratio': np.float64(0.039025231335134206)}
32256 train {'loss_recons': 0.6672500371932983, 'loss_margin': 1.5675303936004639, 'loss_contour': 0.17418226599693298, 'loss': 2.4089624881744385}
32256 eval {'eval_loss_recons': np.float64(0.6434937096688691), 'eval_contour_violation_ratio': np.float64(0.04034714638772343)}
32384 train {'loss_recons': 0.7850550413131714, 'loss_margin': 3.6147263050079346, 'loss_contour': 0.13937516510486603, 'loss': 4.539156436920166}
32384 eval {'eval_loss_recons': np.float64(0.7235565634273636), 'eval_contour_violation_ratio': np.float64(0.04028967182021955)}
32512 train {'loss_recons': 0.6768723726272583, 'loss_margin': 0.8617371320724487, 'loss_contour': 0.1193714290857315, 'loss': 1.6579809188842773}
32512 eval {'eval_loss_recons': np.float64(0.6765100967398869), 'eval_contour_violation_ratio': np.float64(0.040002298982700156)}
32640 train {'loss_recons': 0.8252722024917603, 'loss_margin': 4.075291633605957, 'loss_contour': 0.14219176769256592, 'loss': 5.042755603790283}
32640 eval {'eval_loss_recons': np.float64(0.6808809508691371), 'eval_contour_violation_ratio': np.float64(0.03563423185240531)}
32768 train {'loss_recons': 0.7940720319747925, 'loss_margin': 1.0631506443023682, 'loss_contour': 0.1795850396156311, 'loss': 2.0368077754974365}
32768 eval {'eval_loss_recons': np.float64(0.6594237908394615), 'eval_contour_violation_ratio': np.float64(0.04299097649290189)}
32896 train {'loss_recons': 0.7131210565567017, 'loss_margin': 2.4887495040893555, 'loss_contour': 0.14798159897327423, 'loss': 3.3498520851135254}
32896 eval {'eval_loss_recons': np.float64(0.6775291568422549), 'eval_contour_violation_ratio': np.float64(0.04069199379274671)}
33024 train {'loss_recons': 1.0612807273864746, 'loss_margin': 3.911245346069336, 'loss_contour': 0.2688937485218048, 'loss': 5.241419792175293}
33024 eval {'eval_loss_recons': np.float64(0.8070513128583036), 'eval_contour_violation_ratio': np.float64(0.03804816368756825)}
33152 train {'loss_recons': 0.7592473030090332, 'loss_margin': 2.248065948486328, 'loss_contour': 0.17243371903896332, 'loss': 3.1797468662261963}
33152 eval {'eval_loss_recons': np.float64(0.6602335777546893), 'eval_contour_violation_ratio': np.float64(0.039599977010172996)}
33280 train {'loss_recons': 0.7502052783966064, 'loss_margin': 4.204405784606934, 'loss_contour': 0.15387025475502014, 'loss': 5.108480930328369}
33280 eval {'eval_loss_recons': np.float64(0.6480253042114038), 'eval_contour_violation_ratio': np.float64(0.04120926490028162)}
33408 train {'loss_recons': 0.7195385694503784, 'loss_margin': 3.0087621212005615, 'loss_contour': 0.16953220963478088, 'loss': 3.8978328704833984}
33408 eval {'eval_loss_recons': np.float64(0.6245175077450913), 'eval_contour_violation_ratio': np.float64(0.039082705902638085)}
33536 train {'loss_recons': 0.800926685333252, 'loss_margin': 2.306256055831909, 'loss_contour': 0.14145605266094208, 'loss': 3.248638868331909}
33536 eval {'eval_loss_recons': np.float64(0.7304308979899388), 'eval_contour_violation_ratio': np.float64(0.04074946836025059)}
33664 train {'loss_recons': 0.9294177889823914, 'loss_margin': 2.8648223876953125, 'loss_contour': 0.20824562013149261, 'loss': 4.002485752105713}
33664 eval {'eval_loss_recons': np.float64(0.646482597954049), 'eval_contour_violation_ratio': np.float64(0.03816311282257601)}
33792 train {'loss_recons': 0.7708823680877686, 'loss_margin': 2.629082441329956, 'loss_contour': 0.13494449853897095, 'loss': 3.534909248352051}
33792 eval {'eval_loss_recons': np.float64(0.6540897767222869), 'eval_contour_violation_ratio': np.float64(0.04086441749525835)}
33920 train {'loss_recons': 0.8993732929229736, 'loss_margin': 2.5713024139404297, 'loss_contour': 0.17589621245861053, 'loss': 3.6465718746185303}
33920 eval {'eval_loss_recons': np.float64(0.7316555850277697), 'eval_contour_violation_ratio': np.float64(0.04120926490028162)}
34048 train {'loss_recons': 0.7564306855201721, 'loss_margin': 4.324695110321045, 'loss_contour': 0.150887593626976, 'loss': 5.23201322555542}
34048 eval {'eval_loss_recons': np.float64(0.6217217714343116), 'eval_contour_violation_ratio': np.float64(0.039082705902638085)}
34176 train {'loss_recons': 0.8138915300369263, 'loss_margin': 2.250656843185425, 'loss_contour': 0.10438936948776245, 'loss': 3.168937921524048}
34176 eval {'eval_loss_recons': np.float64(0.7482891181186401), 'eval_contour_violation_ratio': np.float64(0.04040462095522731)}
34304 train {'loss_recons': 0.6911086440086365, 'loss_margin': 3.071108102798462, 'loss_contour': 0.22160887718200684, 'loss': 3.98382568359375}
34304 eval {'eval_loss_recons': np.float64(0.6354568691744812), 'eval_contour_violation_ratio': np.float64(0.04195643427783206)}
34432 train {'loss_recons': 0.9208276271820068, 'loss_margin': 3.242274761199951, 'loss_contour': 0.20712025463581085, 'loss': 4.370222568511963}
34432 eval {'eval_loss_recons': np.float64(0.6786684980841422), 'eval_contour_violation_ratio': np.float64(0.04120926490028162)}
34560 train {'loss_recons': 0.8496302962303162, 'loss_margin': 3.1918745040893555, 'loss_contour': 0.18894116580486298, 'loss': 4.230445861816406}
34560 eval {'eval_loss_recons': np.float64(0.6292008303502716), 'eval_contour_violation_ratio': np.float64(0.03937007874015748)}
34688 train {'loss_recons': 0.7658237814903259, 'loss_margin': 2.089113473892212, 'loss_contour': 0.1776224672794342, 'loss': 3.032559871673584}
34688 eval {'eval_loss_recons': np.float64(0.6196559322975356), 'eval_contour_violation_ratio': np.float64(0.03977240071268464)}
34816 train {'loss_recons': 0.8303380608558655, 'loss_margin': 3.9777297973632812, 'loss_contour': 0.21618208289146423, 'loss': 5.024250030517578}
34816 eval {'eval_loss_recons': np.float64(0.6730273997302364), 'eval_contour_violation_ratio': np.float64(0.04034714638772343)}
34944 train {'loss_recons': 0.7120656967163086, 'loss_margin': 1.8228180408477783, 'loss_contour': 0.15638481080532074, 'loss': 2.6912684440612793}
34944 eval {'eval_loss_recons': np.float64(0.6127679465292064), 'eval_contour_violation_ratio': np.float64(0.039140180470141964)}
35072 train {'loss_recons': 0.9127470254898071, 'loss_margin': 5.399087905883789, 'loss_contour': 0.1441567838191986, 'loss': 6.455991744995117}
35072 eval {'eval_loss_recons': np.float64(0.6275735309298756), 'eval_contour_violation_ratio': np.float64(0.03942755330766136)}
35200 train {'loss_recons': 0.7822123765945435, 'loss_margin': 2.8436501026153564, 'loss_contour': 0.15071330964565277, 'loss': 3.776575803756714}
35200 eval {'eval_loss_recons': np.float64(0.6285511710040312), 'eval_contour_violation_ratio': np.float64(0.03896775676763033)}
35328 train {'loss_recons': 0.8739680051803589, 'loss_margin': 4.5567450523376465, 'loss_contour': 0.17792809009552002, 'loss': 5.608641147613525}
35328 eval {'eval_loss_recons': np.float64(0.6877943266165054), 'eval_contour_violation_ratio': np.float64(0.03850796022759929)}
35456 train {'loss_recons': 0.7774531245231628, 'loss_margin': 4.249077796936035, 'loss_contour': 0.17496244609355927, 'loss': 5.201493263244629}
35456 eval {'eval_loss_recons': np.float64(0.6521135866601635), 'eval_contour_violation_ratio': np.float64(0.03919765503764584)}
35584 train {'loss_recons': 0.7263244986534119, 'loss_margin': 2.9941024780273438, 'loss_contour': 0.14143308997154236, 'loss': 3.8618600368499756}
35584 eval {'eval_loss_recons': np.float64(0.6132723945382592), 'eval_contour_violation_ratio': np.float64(0.03948502787516524)}
35712 train {'loss_recons': 0.6919716596603394, 'loss_margin': 4.266190052032471, 'loss_contour': 0.2008948177099228, 'loss': 5.159056663513184}
35712 eval {'eval_loss_recons': np.float64(0.6274914008433377), 'eval_contour_violation_ratio': np.float64(0.03741594344502558)}
35840 train {'loss_recons': 0.658806324005127, 'loss_margin': 1.5813604593276978, 'loss_contour': 0.1796618551015854, 'loss': 2.419828414916992}
35840 eval {'eval_loss_recons': np.float64(0.6033397619501649), 'eval_contour_violation_ratio': np.float64(0.04080694292775447)}
35968 train {'loss_recons': 0.744677722454071, 'loss_margin': 5.092397212982178, 'loss_contour': 0.2609950006008148, 'loss': 6.098069667816162}
35968 eval {'eval_loss_recons': np.float64(0.630919572247624), 'eval_contour_violation_ratio': np.float64(0.03919765503764584)}
36096 train {'loss_recons': 0.8201198577880859, 'loss_margin': 2.989159345626831, 'loss_contour': 0.11997663229703903, 'loss': 3.929255723953247}
36096 eval {'eval_loss_recons': np.float64(0.6446852418207693), 'eval_contour_violation_ratio': np.float64(0.04023219725271567)}
36224 train {'loss_recons': 0.7047333121299744, 'loss_margin': 4.224758148193359, 'loss_contour': 0.13754937052726746, 'loss': 5.067040920257568}
36224 eval {'eval_loss_recons': np.float64(0.6492287189057141), 'eval_contour_violation_ratio': np.float64(0.03689867233749066)}
36352 train {'loss_recons': 0.674507737159729, 'loss_margin': 2.7131824493408203, 'loss_contour': 0.1589028239250183, 'loss': 3.546592950820923}
36352 eval {'eval_loss_recons': np.float64(0.6300570861383238), 'eval_contour_violation_ratio': np.float64(0.04115179033277774)}
36480 train {'loss_recons': 0.7422845363616943, 'loss_margin': 2.405223846435547, 'loss_contour': 0.24078768491744995, 'loss': 3.388296127319336}
36480 eval {'eval_loss_recons': np.float64(0.6565401112752061), 'eval_contour_violation_ratio': np.float64(0.039657451577676875)}
36608 train {'loss_recons': 0.7513593435287476, 'loss_margin': 3.0027265548706055, 'loss_contour': 0.1766522377729416, 'loss': 3.9307382106781006}
36608 eval {'eval_loss_recons': np.float64(0.6402538074391327), 'eval_contour_violation_ratio': np.float64(0.04132421403528939)}
36736 train {'loss_recons': 0.7836941480636597, 'loss_margin': 4.140346050262451, 'loss_contour': 0.24601396918296814, 'loss': 5.1700544357299805}
36736 eval {'eval_loss_recons': np.float64(0.6228706387590097), 'eval_contour_violation_ratio': np.float64(0.0421288579803437)}
36864 train {'loss_recons': 0.8960181474685669, 'loss_margin': 2.264399528503418, 'loss_contour': 0.15429085493087769, 'loss': 3.3147084712982178}
36864 eval {'eval_loss_recons': np.float64(0.6058487893964026), 'eval_contour_violation_ratio': np.float64(0.039025231335134206)}
36992 train {'loss_recons': 0.7530025243759155, 'loss_margin': 5.278529167175293, 'loss_contour': 0.1134989783167839, 'loss': 6.145030975341797}
36992 eval {'eval_loss_recons': np.float64(0.6232538139220263), 'eval_contour_violation_ratio': np.float64(0.03810563825507213)}
37120 train {'loss_recons': 0.7206981778144836, 'loss_margin': 2.502913475036621, 'loss_contour': 0.15035957098007202, 'loss': 3.3739712238311768}
37120 eval {'eval_loss_recons': np.float64(0.6266340746630179), 'eval_contour_violation_ratio': np.float64(0.039025231335134206)}
37248 train {'loss_recons': 0.6602766513824463, 'loss_margin': 2.5157766342163086, 'loss_contour': 0.18076787889003754, 'loss': 3.356821060180664}
37248 eval {'eval_loss_recons': np.float64(0.6017519522696483), 'eval_contour_violation_ratio': np.float64(0.039599977010172996)}
37376 train {'loss_recons': 0.8503926992416382, 'loss_margin': 5.29062557220459, 'loss_contour': 0.23824648559093475, 'loss': 6.379264831542969}
37376 eval {'eval_loss_recons': np.float64(0.6686074666998464), 'eval_contour_violation_ratio': np.float64(0.04028967182021955)}
37504 train {'loss_recons': 0.6777288317680359, 'loss_margin': 3.604384660720825, 'loss_contour': 0.184226855635643, 'loss': 4.4663405418396}
37504 eval {'eval_loss_recons': np.float64(0.6487713759480178), 'eval_contour_violation_ratio': np.float64(0.038737858497614804)}
37632 train {'loss_recons': 0.7127417325973511, 'loss_margin': 4.697469234466553, 'loss_contour': 0.20053566992282867, 'loss': 5.61074686050415}
37632 eval {'eval_loss_recons': np.float64(0.6136988833904913), 'eval_contour_violation_ratio': np.float64(0.03994482441519628)}
37760 train {'loss_recons': 0.8327212929725647, 'loss_margin': 4.903323173522949, 'loss_contour': 0.1686110943555832, 'loss': 5.904655456542969}
37760 eval {'eval_loss_recons': np.float64(0.671074646083044), 'eval_contour_violation_ratio': np.float64(0.03977240071268464)}
37888 train {'loss_recons': 0.6790503859519958, 'loss_margin': 2.9687142372131348, 'loss_contour': 0.15267755091190338, 'loss': 3.8004422187805176}
37888 eval {'eval_loss_recons': np.float64(0.590175668845298), 'eval_contour_violation_ratio': np.float64(0.0393126041726536)}
38016 train {'loss_recons': 0.8183913826942444, 'loss_margin': 2.8328137397766113, 'loss_contour': 0.16884353756904602, 'loss': 3.8200485706329346}
38016 eval {'eval_loss_recons': np.float64(0.6338465627082025), 'eval_contour_violation_ratio': np.float64(0.036726248634979024)}
38144 train {'loss_recons': 0.7854008674621582, 'loss_margin': 2.53969144821167, 'loss_contour': 0.13864238560199738, 'loss': 3.4637346267700195}
38144 eval {'eval_loss_recons': np.float64(0.6218104959372569), 'eval_contour_violation_ratio': np.float64(0.036726248634979024)}
38272 train {'loss_recons': 0.7082489728927612, 'loss_margin': 4.5012125968933105, 'loss_contour': 0.14457829296588898, 'loss': 5.354040145874023}
38272 eval {'eval_loss_recons': np.float64(0.6191912640846249), 'eval_contour_violation_ratio': np.float64(0.03850796022759929)}
38400 train {'loss_recons': 0.7137753963470459, 'loss_margin': 4.255044937133789, 'loss_contour': 0.13145191967487335, 'loss': 5.1002726554870605}
38400 eval {'eval_loss_recons': np.float64(0.5947006618899803), 'eval_contour_violation_ratio': np.float64(0.039140180470141964)}
38528 train {'loss_recons': 0.660246729850769, 'loss_margin': 5.293832302093506, 'loss_contour': 0.1493273824453354, 'loss': 6.1034064292907715}
38528 eval {'eval_loss_recons': np.float64(0.5868094559949112), 'eval_contour_violation_ratio': np.float64(0.03879533306511868)}
38656 train {'loss_recons': 0.727311909198761, 'loss_margin': 3.485640048980713, 'loss_contour': 0.12416063994169235, 'loss': 4.337112903594971}
38656 eval {'eval_loss_recons': np.float64(0.6001821032516662), 'eval_contour_violation_ratio': np.float64(0.03810563825507213)}
38784 train {'loss_recons': 0.700287938117981, 'loss_margin': 4.016079902648926, 'loss_contour': 0.19149643182754517, 'loss': 4.907864093780518}
38784 eval {'eval_loss_recons': np.float64(0.6491511637707772), 'eval_contour_violation_ratio': np.float64(0.03753089258003334)}
38912 train {'loss_recons': 0.7265491485595703, 'loss_margin': 3.719076156616211, 'loss_contour': 0.2340698093175888, 'loss': 4.679695129394531}
38912 eval {'eval_loss_recons': np.float64(0.6118933350560731), 'eval_contour_violation_ratio': np.float64(0.039025231335134206)}
39040 train {'loss_recons': 0.5998612642288208, 'loss_margin': 1.9001286029815674, 'loss_contour': 0.1716136783361435, 'loss': 2.6716036796569824}
39040 eval {'eval_loss_recons': np.float64(0.6286207766395713), 'eval_contour_violation_ratio': np.float64(0.03850796022759929)}
39168 train {'loss_recons': 0.7320361137390137, 'loss_margin': 6.035809516906738, 'loss_contour': 0.19182758033275604, 'loss': 6.9596734046936035}
39168 eval {'eval_loss_recons': np.float64(0.6071868176034639), 'eval_contour_violation_ratio': np.float64(0.03689867233749066)}
39296 train {'loss_recons': 0.698677659034729, 'loss_margin': 2.7215147018432617, 'loss_contour': 0.17162977159023285, 'loss': 3.5918219089508057}
39296 eval {'eval_loss_recons': np.float64(0.6132288307560446), 'eval_contour_violation_ratio': np.float64(0.03977240071268464)}
39424 train {'loss_recons': 0.7458354234695435, 'loss_margin': 4.335087776184082, 'loss_contour': 0.2632245719432831, 'loss': 5.344147682189941}
39424 eval {'eval_loss_recons': np.float64(0.6600707113382961), 'eval_contour_violation_ratio': np.float64(0.03810563825507213)}
39552 train {'loss_recons': 0.7370925545692444, 'loss_margin': 5.92979097366333, 'loss_contour': 0.2294246405363083, 'loss': 6.896307945251465}
39552 eval {'eval_loss_recons': np.float64(0.6323582889525405), 'eval_contour_violation_ratio': np.float64(0.037645841715041095)}
39680 train {'loss_recons': 0.773768961429596, 'loss_margin': 3.5248289108276367, 'loss_contour': 0.16963471472263336, 'loss': 4.46823263168335}
39680 eval {'eval_loss_recons': np.float64(0.6676155438451604), 'eval_contour_violation_ratio': np.float64(0.03971492614518076)}
39808 train {'loss_recons': 0.779790997505188, 'loss_margin': 5.2496795654296875, 'loss_contour': 0.14869368076324463, 'loss': 6.178164005279541}
39808 eval {'eval_loss_recons': np.float64(0.615599267650409), 'eval_contour_violation_ratio': np.float64(0.03833553652508765)}
39936 train {'loss_recons': 0.6682083606719971, 'loss_margin': 3.195054292678833, 'loss_contour': 0.140485942363739, 'loss': 4.003748416900635}
39936 eval {'eval_loss_recons': np.float64(0.6175208699651127), 'eval_contour_violation_ratio': np.float64(0.037186045175010056)}
40064 train {'loss_recons': 0.7220306396484375, 'loss_margin': 4.636885643005371, 'loss_contour': 0.19925236701965332, 'loss': 5.558168411254883}
40064 eval {'eval_loss_recons': np.float64(0.5993623081426614), 'eval_contour_violation_ratio': np.float64(0.037588367147537216)}
40192 train {'loss_recons': 0.7385029196739197, 'loss_margin': 3.85528302192688, 'loss_contour': 0.16618338227272034, 'loss': 4.759969234466553}
40192 eval {'eval_loss_recons': np.float64(0.5822965578119865), 'eval_contour_violation_ratio': np.float64(0.03712857060750618)}
40320 train {'loss_recons': 0.6731598973274231, 'loss_margin': 2.623379945755005, 'loss_contour': 0.15499453246593475, 'loss': 3.4515342712402344}
40320 eval {'eval_loss_recons': np.float64(0.5930194772317121), 'eval_contour_violation_ratio': np.float64(0.0370710960400023)}
40448 train {'loss_recons': 0.6907228231430054, 'loss_margin': 2.658843994140625, 'loss_contour': 0.2061513215303421, 'loss': 3.555718183517456}
40448 eval {'eval_loss_recons': np.float64(0.6557208523190639), 'eval_contour_violation_ratio': np.float64(0.03701362147249842)}
40576 train {'loss_recons': 0.6651399731636047, 'loss_margin': 4.110969066619873, 'loss_contour': 0.14459195733070374, 'loss': 4.920701026916504}
40576 eval {'eval_loss_recons': np.float64(0.5844418826618892), 'eval_contour_violation_ratio': np.float64(0.0373584688775217)}
40704 train {'loss_recons': 0.6574372053146362, 'loss_margin': 1.2385836839675903, 'loss_contour': 0.22363345324993134, 'loss': 2.119654417037964}
40704 eval {'eval_loss_recons': np.float64(0.5886037538766421), 'eval_contour_violation_ratio': np.float64(0.04172653600781654)}
40832 train {'loss_recons': 0.6566117405891418, 'loss_margin': 3.7049221992492676, 'loss_contour': 0.22088052332401276, 'loss': 4.582414627075195}
40832 eval {'eval_loss_recons': np.float64(0.6160162854869994), 'eval_contour_violation_ratio': np.float64(0.03845048566009541)}
40960 train {'loss_recons': 0.7009821534156799, 'loss_margin': 4.65592098236084, 'loss_contour': 0.11599583923816681, 'loss': 5.472898960113525}
40960 eval {'eval_loss_recons': np.float64(0.6074440981533139), 'eval_contour_violation_ratio': np.float64(0.036323926662451864)}
41088 train {'loss_recons': 0.7521028518676758, 'loss_margin': 2.5147643089294434, 'loss_contour': 0.12386962026357651, 'loss': 3.3907368183135986}
41088 eval {'eval_loss_recons': np.float64(0.6068190533421004), 'eval_contour_violation_ratio': np.float64(0.03643887579745962)}
41216 train {'loss_recons': 0.7328038215637207, 'loss_margin': 4.103935241699219, 'loss_contour': 0.12348194420337677, 'loss': 4.960220813751221}
41216 eval {'eval_loss_recons': np.float64(0.5998473617070494), 'eval_contour_violation_ratio': np.float64(0.03655382493246738)}
41344 train {'loss_recons': 0.7730164527893066, 'loss_margin': 4.779655456542969, 'loss_contour': 0.11717109382152557, 'loss': 5.6698431968688965}
41344 eval {'eval_loss_recons': np.float64(0.6517577001207023), 'eval_contour_violation_ratio': np.float64(0.03833553652508765)}
41472 train {'loss_recons': 0.7607800364494324, 'loss_margin': 2.8201286792755127, 'loss_contour': 0.24396370351314545, 'loss': 3.8248724937438965}
41472 eval {'eval_loss_recons': np.float64(0.6384323420518135), 'eval_contour_violation_ratio': np.float64(0.039025231335134206)}
41600 train {'loss_recons': 0.6788355708122253, 'loss_margin': 3.637946128845215, 'loss_contour': 0.1524321287870407, 'loss': 4.469213485717773}
41600 eval {'eval_loss_recons': np.float64(0.6138399381492172), 'eval_contour_violation_ratio': np.float64(0.038565434795103166)}
41728 train {'loss_recons': 0.6439909934997559, 'loss_margin': 3.0069656372070312, 'loss_contour': 0.1538219302892685, 'loss': 3.804778575897217}
41728 eval {'eval_loss_recons': np.float64(0.6011515515956556), 'eval_contour_violation_ratio': np.float64(0.03879533306511868)}
41856 train {'loss_recons': 0.8122040033340454, 'loss_margin': 4.174803733825684, 'loss_contour': 0.18076041340827942, 'loss': 5.1677680015563965}
41856 eval {'eval_loss_recons': np.float64(0.6384490689420588), 'eval_contour_violation_ratio': np.float64(0.0412667394677855)}
41984 train {'loss_recons': 0.6763020157814026, 'loss_margin': 2.292271614074707, 'loss_contour': 0.1513792872428894, 'loss': 3.119952917098999}
41984 eval {'eval_loss_recons': np.float64(0.5870793280092936), 'eval_contour_violation_ratio': np.float64(0.03793321455256049)}
42112 train {'loss_recons': 0.775591254234314, 'loss_margin': 6.629715442657471, 'loss_contour': 0.1579856276512146, 'loss': 7.563292503356934}
42112 eval {'eval_loss_recons': np.float64(0.6487278667448059), 'eval_contour_violation_ratio': np.float64(0.03603655382493247)}
42240 train {'loss_recons': 0.7986277341842651, 'loss_margin': 3.88643741607666, 'loss_contour': 0.16926582157611847, 'loss': 4.854331016540527}
42240 eval {'eval_loss_recons': np.float64(0.6286739278928734), 'eval_contour_violation_ratio': np.float64(0.032875452612219094)}
42368 train {'loss_recons': 0.697727620601654, 'loss_margin': 3.8884592056274414, 'loss_contour': 0.1317581683397293, 'loss': 4.717945098876953}
42368 eval {'eval_loss_recons': np.float64(0.5914159770654407), 'eval_contour_violation_ratio': np.float64(0.041036841197769985)}
42496 train {'loss_recons': 0.7300745248794556, 'loss_margin': 5.267002582550049, 'loss_contour': 0.21378350257873535, 'loss': 6.210860252380371}
42496 eval {'eval_loss_recons': np.float64(0.5753796325548967), 'eval_contour_violation_ratio': np.float64(0.0373584688775217)}
42624 train {'loss_recons': 0.6645662784576416, 'loss_margin': 2.5133862495422363, 'loss_contour': 0.16487447917461395, 'loss': 3.342827081680298}
42624 eval {'eval_loss_recons': np.float64(0.6138636920128909), 'eval_contour_violation_ratio': np.float64(0.03776079085004885)}
42752 train {'loss_recons': 0.8972562551498413, 'loss_margin': 4.280614852905273, 'loss_contour': 0.12983861565589905, 'loss': 5.307709693908691}
42752 eval {'eval_loss_recons': np.float64(0.5964889502028644), 'eval_contour_violation_ratio': np.float64(0.0398873498476924)}
42880 train {'loss_recons': 0.6318285465240479, 'loss_margin': 3.638057231903076, 'loss_contour': 0.2366415560245514, 'loss': 4.506527423858643}
42880 eval {'eval_loss_recons': np.float64(0.5953376355519671), 'eval_contour_violation_ratio': np.float64(0.038622909362607045)}
43008 train {'loss_recons': 0.7031564712524414, 'loss_margin': 5.520223617553711, 'loss_contour': 0.15248727798461914, 'loss': 6.3758673667907715}
43008 eval {'eval_loss_recons': np.float64(0.6164073722479804), 'eval_contour_violation_ratio': np.float64(0.037703316282544974)}
43136 train {'loss_recons': 0.7072391510009766, 'loss_margin': 1.4043201208114624, 'loss_contour': 0.15019038319587708, 'loss': 2.261749744415283}
43136 eval {'eval_loss_recons': np.float64(0.6458682622165579), 'eval_contour_violation_ratio': np.float64(0.03879533306511868)}
43264 train {'loss_recons': 0.7215522527694702, 'loss_margin': 2.4517440795898438, 'loss_contour': 0.21003705263137817, 'loss': 3.383333444595337}
43264 eval {'eval_loss_recons': np.float64(0.6283836441570115), 'eval_contour_violation_ratio': np.float64(0.04189895971032818)}
43392 train {'loss_recons': 0.8337807655334473, 'loss_margin': 2.948141098022461, 'loss_contour': 0.18140192329883575, 'loss': 3.9633238315582275}
43392 eval {'eval_loss_recons': np.float64(0.7378903274974133), 'eval_contour_violation_ratio': np.float64(0.03546180814989367)}
43520 train {'loss_recons': 0.6787655353546143, 'loss_margin': 4.12281608581543, 'loss_contour': 0.16892853379249573, 'loss': 4.970510005950928}
43520 eval {'eval_loss_recons': np.float64(0.6191148331655466), 'eval_contour_violation_ratio': np.float64(0.04063451922524283)}
43648 train {'loss_recons': 0.6932456493377686, 'loss_margin': 4.60537052154541, 'loss_contour': 0.21733209490776062, 'loss': 5.515948295593262}
43648 eval {'eval_loss_recons': np.float64(0.5649823610247239), 'eval_contour_violation_ratio': np.float64(0.0393126041726536)}
43776 train {'loss_recons': 0.6437615156173706, 'loss_margin': 3.543862819671631, 'loss_contour': 0.16370779275894165, 'loss': 4.351332187652588}
43776 eval {'eval_loss_recons': np.float64(0.5850402933890382), 'eval_contour_violation_ratio': np.float64(0.038680383930110925)}
43904 train {'loss_recons': 0.6225383281707764, 'loss_margin': 2.7578272819519043, 'loss_contour': 0.16171994805335999, 'loss': 3.542085647583008}
43904 eval {'eval_loss_recons': np.float64(0.5573407750096482), 'eval_contour_violation_ratio': np.float64(0.037645841715041095)}
44032 train {'loss_recons': 0.6711007952690125, 'loss_margin': 4.518608093261719, 'loss_contour': 0.15795451402664185, 'loss': 5.347663402557373}
44032 eval {'eval_loss_recons': np.float64(0.5638130545563519), 'eval_contour_violation_ratio': np.float64(0.03839301109259153)}
44160 train {'loss_recons': 0.6774958372116089, 'loss_margin': 2.584315299987793, 'loss_contour': 0.16394105553627014, 'loss': 3.4257524013519287}
44160 eval {'eval_loss_recons': np.float64(0.5970969497359868), 'eval_contour_violation_ratio': np.float64(0.03954250244266912)}
44288 train {'loss_recons': 0.7228914499282837, 'loss_margin': 3.377756118774414, 'loss_contour': 0.5310730338096619, 'loss': 4.631720542907715}
44288 eval {'eval_loss_recons': np.float64(0.6322758066146086), 'eval_contour_violation_ratio': np.float64(0.037645841715041095)}
44416 train {'loss_recons': 0.5796800851821899, 'loss_margin': 4.894165992736816, 'loss_contour': 0.14954639971256256, 'loss': 5.623392581939697}
44416 eval {'eval_loss_recons': np.float64(0.5786276763590742), 'eval_contour_violation_ratio': np.float64(0.036726248634979024)}
44544 train {'loss_recons': 0.7617074251174927, 'loss_margin': 6.021209239959717, 'loss_contour': 0.12148799747228622, 'loss': 6.904404640197754}
44544 eval {'eval_loss_recons': np.float64(0.6099310898693094), 'eval_contour_violation_ratio': np.float64(0.03845048566009541)}
44672 train {'loss_recons': 0.7461587190628052, 'loss_margin': 7.063604831695557, 'loss_contour': 0.15913988649845123, 'loss': 7.968903541564941}
44672 eval {'eval_loss_recons': np.float64(0.5839588054895481), 'eval_contour_violation_ratio': np.float64(0.037645841715041095)}
44800 train {'loss_recons': 0.9437198638916016, 'loss_margin': 4.899690628051758, 'loss_contour': 0.22909684479236603, 'loss': 6.072507381439209}
44800 eval {'eval_loss_recons': np.float64(0.654373959300477), 'eval_contour_violation_ratio': np.float64(0.041439163170297146)}
44928 train {'loss_recons': 0.7194044589996338, 'loss_margin': 3.518662929534912, 'loss_contour': 0.2320263832807541, 'loss': 4.470094203948975}
44928 eval {'eval_loss_recons': np.float64(0.5533633526291167), 'eval_contour_violation_ratio': np.float64(0.036668774067475145)}
45056 train {'loss_recons': 0.7066574692726135, 'loss_margin': 4.322443008422852, 'loss_contour': 0.18548274040222168, 'loss': 5.214583396911621}
45056 eval {'eval_loss_recons': np.float64(0.5513039183475936), 'eval_contour_violation_ratio': np.float64(0.0373584688775217)}
45184 train {'loss_recons': 0.6584828495979309, 'loss_margin': 5.568305015563965, 'loss_contour': 0.14823992550373077, 'loss': 6.375028133392334}
45184 eval {'eval_loss_recons': np.float64(0.5869062685528229), 'eval_contour_violation_ratio': np.float64(0.039025231335134206)}
45312 train {'loss_recons': 0.7782989740371704, 'loss_margin': 4.518729209899902, 'loss_contour': 0.15535497665405273, 'loss': 5.452383041381836}
45312 eval {'eval_loss_recons': np.float64(0.6355383030924753), 'eval_contour_violation_ratio': np.float64(0.03689867233749066)}
45440 train {'loss_recons': 0.6377425193786621, 'loss_margin': 2.4439945220947266, 'loss_contour': 0.21302801370620728, 'loss': 3.294764995574951}
45440 eval {'eval_loss_recons': np.float64(0.5738724398589067), 'eval_contour_violation_ratio': np.float64(0.039082705902638085)}
45568 train {'loss_recons': 0.6727858781814575, 'loss_margin': 5.094447135925293, 'loss_contour': 0.1647137552499771, 'loss': 5.931946754455566}
45568 eval {'eval_loss_recons': np.float64(0.5652874369655756), 'eval_contour_violation_ratio': np.float64(0.03689867233749066)}
45696 train {'loss_recons': 0.7214422225952148, 'loss_margin': 3.7900540828704834, 'loss_contour': 0.15512420237064362, 'loss': 4.66662073135376}
45696 eval {'eval_loss_recons': np.float64(0.595194798489747), 'eval_contour_violation_ratio': np.float64(0.037588367147537216)}
45824 train {'loss_recons': 0.7875858545303345, 'loss_margin': 2.588275671005249, 'loss_contour': 0.14189064502716064, 'loss': 3.517752170562744}
45824 eval {'eval_loss_recons': np.float64(0.5490466473195732), 'eval_contour_violation_ratio': np.float64(0.038680383930110925)}
45952 train {'loss_recons': 0.7074974775314331, 'loss_margin': 4.983129024505615, 'loss_contour': 0.18251176178455353, 'loss': 5.873138427734375}
45952 eval {'eval_loss_recons': np.float64(0.5870599418438818), 'eval_contour_violation_ratio': np.float64(0.03753089258003334)}
46080 train {'loss_recons': 0.6833992004394531, 'loss_margin': 2.2345657348632812, 'loss_contour': 0.15255728363990784, 'loss': 3.0705223083496094}
46080 eval {'eval_loss_recons': np.float64(0.5821908100233685), 'eval_contour_violation_ratio': np.float64(0.03850796022759929)}
46208 train {'loss_recons': 0.6504513025283813, 'loss_margin': 4.242520332336426, 'loss_contour': 0.13012142479419708, 'loss': 5.023092746734619}
46208 eval {'eval_loss_recons': np.float64(0.5552534969775367), 'eval_contour_violation_ratio': np.float64(0.03712857060750618)}
46336 train {'loss_recons': 0.6567684412002563, 'loss_margin': 7.231385231018066, 'loss_contour': 0.1974174827337265, 'loss': 8.0855712890625}
46336 eval {'eval_loss_recons': np.float64(0.5736098555721417), 'eval_contour_violation_ratio': np.float64(0.03609402839243635)}
46464 train {'loss_recons': 0.6901776790618896, 'loss_margin': 1.2250502109527588, 'loss_contour': 0.15304049849510193, 'loss': 2.068268299102783}
46464 eval {'eval_loss_recons': np.float64(0.5934658710194414), 'eval_contour_violation_ratio': np.float64(0.03816311282257601)}
46592 train {'loss_recons': 0.6746028661727905, 'loss_margin': 4.355408191680908, 'loss_contour': 0.12887494266033173, 'loss': 5.158885955810547}
46592 eval {'eval_loss_recons': np.float64(0.5654594376652432), 'eval_contour_violation_ratio': np.float64(0.03695614690499454)}
46720 train {'loss_recons': 0.6699822545051575, 'loss_margin': 4.707212448120117, 'loss_contour': 0.1315913200378418, 'loss': 5.508786201477051}
46720 eval {'eval_loss_recons': np.float64(0.6060700409122487), 'eval_contour_violation_ratio': np.float64(0.040979366630266106)}
46848 train {'loss_recons': 0.649764895439148, 'loss_margin': 4.0592780113220215, 'loss_contour': 0.14862866699695587, 'loss': 4.857671737670898}
46848 eval {'eval_loss_recons': np.float64(0.5656811428801843), 'eval_contour_violation_ratio': np.float64(0.03655382493246738)}
46976 train {'loss_recons': 0.6609363555908203, 'loss_margin': 3.087799072265625, 'loss_contour': 0.20605793595314026, 'loss': 3.9547934532165527}
46976 eval {'eval_loss_recons': np.float64(0.551753459575746), 'eval_contour_violation_ratio': np.float64(0.03799068912006437)}
47104 train {'loss_recons': 0.7446390390396118, 'loss_margin': 4.588779449462891, 'loss_contour': 0.14969462156295776, 'loss': 5.483112812042236}
47104 eval {'eval_loss_recons': np.float64(0.597404651356894), 'eval_contour_violation_ratio': np.float64(0.03505948617736652)}
47232 train {'loss_recons': 0.771014392375946, 'loss_margin': 4.522623062133789, 'loss_contour': 0.17982371151447296, 'loss': 5.473461151123047}
47232 eval {'eval_loss_recons': np.float64(0.5703012446045808), 'eval_contour_violation_ratio': np.float64(0.038737858497614804)}
47360 train {'loss_recons': 0.5972283482551575, 'loss_margin': 5.40104866027832, 'loss_contour': 0.21139806509017944, 'loss': 6.209675312042236}
47360 eval {'eval_loss_recons': np.float64(0.5445502922741774), 'eval_contour_violation_ratio': np.float64(0.03689867233749066)}
47488 train {'loss_recons': 0.6788458824157715, 'loss_margin': 3.8603382110595703, 'loss_contour': 0.14283481240272522, 'loss': 4.682018756866455}
47488 eval {'eval_loss_recons': np.float64(0.564155612498717), 'eval_contour_violation_ratio': np.float64(0.037300994310017814)}
47616 train {'loss_recons': 0.5753772258758545, 'loss_margin': 2.143264055252075, 'loss_contour': 0.1529710590839386, 'loss': 2.871612310409546}
47616 eval {'eval_loss_recons': np.float64(0.570662193579054), 'eval_contour_violation_ratio': np.float64(0.03793321455256049)}
47744 train {'loss_recons': 0.6666622161865234, 'loss_margin': 4.050724029541016, 'loss_contour': 0.20738571882247925, 'loss': 4.924771785736084}
47744 eval {'eval_loss_recons': np.float64(0.5514287971718924), 'eval_contour_violation_ratio': np.float64(0.034312316799816084)}
47872 train {'loss_recons': 0.7904454469680786, 'loss_margin': 5.661798477172852, 'loss_contour': 0.13961350917816162, 'loss': 6.591857433319092}
47872 eval {'eval_loss_recons': np.float64(0.5517991853551187), 'eval_contour_violation_ratio': np.float64(0.038680383930110925)}
48000 train {'loss_recons': 0.720602810382843, 'loss_margin': 7.587973117828369, 'loss_contour': 0.16197051107883453, 'loss': 8.470545768737793}
48000 eval {'eval_loss_recons': np.float64(0.5811905298477454), 'eval_contour_violation_ratio': np.float64(0.037645841715041095)}
48128 train {'loss_recons': 0.6315373182296753, 'loss_margin': 2.822788953781128, 'loss_contour': 0.1401870846748352, 'loss': 3.594513177871704}
48128 eval {'eval_loss_recons': np.float64(0.5712254732605249), 'eval_contour_violation_ratio': np.float64(0.03948502787516524)}
48256 train {'loss_recons': 0.5852078199386597, 'loss_margin': 4.518293857574463, 'loss_contour': 0.11779459565877914, 'loss': 5.221296310424805}
48256 eval {'eval_loss_recons': np.float64(0.5394010464559916), 'eval_contour_violation_ratio': np.float64(0.03781826541755273)}
48384 train {'loss_recons': 0.6451920866966248, 'loss_margin': 2.373762845993042, 'loss_contour': 0.13613708317279816, 'loss': 3.1550920009613037}
48384 eval {'eval_loss_recons': np.float64(0.5896213888603943), 'eval_contour_violation_ratio': np.float64(0.0367837232024829)}
48512 train {'loss_recons': 0.6782872676849365, 'loss_margin': 5.009637832641602, 'loss_contour': 0.1448626071214676, 'loss': 5.832787990570068}
48512 eval {'eval_loss_recons': np.float64(0.5784477690138466), 'eval_contour_violation_ratio': np.float64(0.03781826541755273)}
48640 train {'loss_recons': 0.610794723033905, 'loss_margin': 3.876784324645996, 'loss_contour': 0.12291423231363297, 'loss': 4.610493183135986}
48640 eval {'eval_loss_recons': np.float64(0.5436631482904959), 'eval_contour_violation_ratio': np.float64(0.037588367147537216)}
48768 train {'loss_recons': 0.7083148956298828, 'loss_margin': 5.288625240325928, 'loss_contour': 0.17755889892578125, 'loss': 6.174499034881592}
48768 eval {'eval_loss_recons': np.float64(0.5324512447522844), 'eval_contour_violation_ratio': np.float64(0.03689867233749066)}
48896 train {'loss_recons': 0.5744513869285583, 'loss_margin': 2.00093412399292, 'loss_contour': 0.17946940660476685, 'loss': 2.754854917526245}
48896 eval {'eval_loss_recons': np.float64(0.561336333903913), 'eval_contour_violation_ratio': np.float64(0.03919765503764584)}
49024 train {'loss_recons': 0.7533805966377258, 'loss_margin': 3.9441051483154297, 'loss_contour': 0.2248910367488861, 'loss': 4.922377109527588}
49024 eval {'eval_loss_recons': np.float64(0.5549753917633887), 'eval_contour_violation_ratio': np.float64(0.03695614690499454)}
49152 train {'loss_recons': 0.689679741859436, 'loss_margin': 4.488298416137695, 'loss_contour': 0.1773175597190857, 'loss': 5.355295658111572}
49152 eval {'eval_loss_recons': np.float64(0.5620323264001389), 'eval_contour_violation_ratio': np.float64(0.037243519742513935)}
49280 train {'loss_recons': 0.6067517995834351, 'loss_margin': 1.5116987228393555, 'loss_contour': 0.1452207714319229, 'loss': 2.263671398162842}
49280 eval {'eval_loss_recons': np.float64(0.5365317756851142), 'eval_contour_violation_ratio': np.float64(0.03787573998505661)}
49408 train {'loss_recons': 0.8014147877693176, 'loss_margin': 9.400764465332031, 'loss_contour': 0.1314154863357544, 'loss': 10.33359432220459}
49408 eval {'eval_loss_recons': np.float64(0.6181349976550203), 'eval_contour_violation_ratio': np.float64(0.03804816368756825)}
49536 train {'loss_recons': 0.6931096315383911, 'loss_margin': 2.8202810287475586, 'loss_contour': 0.14014475047588348, 'loss': 3.6535353660583496}
49536 eval {'eval_loss_recons': np.float64(0.6129879371332354), 'eval_contour_violation_ratio': np.float64(0.039140180470141964)}
49664 train {'loss_recons': 0.6657178401947021, 'loss_margin': 3.3024375438690186, 'loss_contour': 0.1373748481273651, 'loss': 4.105530261993408}
49664 eval {'eval_loss_recons': np.float64(0.586616873157885), 'eval_contour_violation_ratio': np.float64(0.04218633254784758)}
49792 train {'loss_recons': 0.5995160341262817, 'loss_margin': 4.378806114196777, 'loss_contour': 0.19895941019058228, 'loss': 5.177281379699707}
49792 eval {'eval_loss_recons': np.float64(0.5400453304759552), 'eval_contour_violation_ratio': np.float64(0.036668774067475145)}
49920 train {'loss_recons': 0.7115693092346191, 'loss_margin': 3.8921220302581787, 'loss_contour': 0.10671568661928177, 'loss': 4.71040678024292}
49920 eval {'eval_loss_recons': np.float64(0.6667578179908062), 'eval_contour_violation_ratio': np.float64(0.0367837232024829)}

Ignoring Step 4 and 5 (NEW)¶

We chose to ignore Step 4 and 5, as the old interface would not be compatible with a model that expects 26 bins.

After Step 5, we will describe our new pipeline for encoding what the user types and interpreting that output file.

In [ ]:
# # @title **(Step 4)** Port trained decoder parameters to Tensorflow.js format

# # @markdown In this step, we will use the TensorFlow.js Python library to export our model's parameters in a binary format, to be loaded later by the JavaScript client.

# !!pip install tensorflowjs

# from tensorflowjs.write_weights import write_weights

# # Load saved model dict
# d = torch.load("piano_genie/model.pt", map_location=torch.device("cpu"))
# d = {k: v.numpy() for k, v in d.items()}

# # Convert to tensorflow-js format
# pathlib.Path("piano_genie/dec_tfjs").mkdir(exist_ok=True)
# write_weights(
#     [[{"name": k, "data": v} for k, v in d.items() if k.startswith("dec")]],
#     "piano_genie/dec_tfjs",
# )
In [ ]:
# # @title **(Step 5)** Create test case to check correctness of JavaScript port

# # @markdown Finally, we will serialize a sequence of inputs to and outputs from our trained model to create a test case for our JavaScript reimplementation.
# # @markdown This is critically important—I have ported many models from Python to JavaScript and have yet to get it right on the first try.
# # @markdown Porting models from PyTorch to TensorFlow.js is additionally tricky because parameters of the same shape are often used differently by the two APIs.

# # Restore model from saved checkpoint
# device = torch.device("cpu")
# with open("piano_genie/cfg.json", "r") as f:
#     cfg = json.load(f)
# model = PianoGenieAutoencoder(cfg)
# model.load_state_dict(torch.load("piano_genie/model.pt", map_location=device))
# model.eval()
# model.to(device)

# # Serialize a batch of inputs/outputs as JSON
# with torch.no_grad():
#     ground_truth_keys, input_dts = performances_to_batch(
#         [DATASET["validation"][0]], device, train=False
#     )
#     output_logits, input_buttons = model(ground_truth_keys, input_dts)
#     input_buttons = model.quant.real_to_discrete(input_buttons)

#     input_dts = input_dts[0].cpu().numpy().tolist()
#     ground_truth_keys = ground_truth_keys[0].cpu().numpy().tolist()
#     input_keys = [PIANO_NUM_KEYS] + ground_truth_keys[:-1]
#     input_buttons = input_buttons[0].cpu().numpy().tolist()
#     output_logits = output_logits[0].cpu().numpy().tolist()

#     test = {
#         n: eval(n)
#         for n in ["input_dts", "input_keys", "input_buttons", "output_logits"]
#     }
#     with open(pathlib.Path("piano_genie", "test.json"), "w") as f:
#         f.write(json.dumps(test))

New Step 4 and 5 (NEW)¶

Taking User Input¶

We created a new HTML website that encodes user input as a CSV file. The file can be found on our Github here: https://github.com/AnniePhan02/CSE253-Assignment2/blob/main/task2/index.html

Feeding user input file into the model¶

Running Steps 1-5 will result in two files: cfg.json and model.pt.

The key function to generate output from the model is the step() function.

Example CFG.json

{
  "seed": 0,
  "num_buttons": 26,
  "data_delta_time_max": 1.0,
  "data_augment_time_stretch_max": 0.05,
  "data_augment_transpose_max": 6,
  "model_rnn_dim": 128,
  "model_rnn_num_layers": 2,
  "batch_size": 32,
  "seq_len": 128,
  "lr": 0.0003,
  "loss_margin_multiplier": 1.0,
  "loss_contour_multiplier": 1.0,
  "summarize_frequency": 128,
  "eval_frequency": 128,
  "max_num_steps": 50000
}

Running the Newly Generated Model¶

The repository we pulled from had some base code for training the model, but not for running it. We created the following script with the assistance of copilot in order to feed input and interpret output from the model.

Defining Bins Based On Keys¶

Since we have 26 keys, we made 26 bins

def letter_to_button_26(letter):
    # Map letters to button indices for a 26-letter keyboard layout
    alphabet = "abcdefghijklmnopqrstuvwxyz"
    if letter in alphabet:
        return alphabet.index(letter)
    else:
        return 0  # Default case for unsupported characters

Single Note Eval¶

In [58]:
import torch
import json

cfg = json.load(open("cfg.json"))
model = PianoGenieAutoencoder(cfg)
model.load_state_dict(torch.load("model.pt", map_location="cpu"))
model.eval()

# 2) Prepare decoder state for a single stream
#    We’ll run batch_size=1 since we’re in interactive mode
h = model.dec.init_hidden(batch_size=1)
# k_prev holds the last output key index; start with SOS
k_prev = torch.full((1, 1), SOS, dtype=torch.long)

# 3) Now, each time the user presses a button, you have:
#    b_i: integer 0…7    (which button)
#    t_i: float          (absolute onset time in seconds)
#    v_i: int 1…127      (velocity)
#
# You’ll convert these into tensors, call the decoder, sample/argmax,
# and then feed that key back in as the next k_prev.


def step(b_i, t_i, v_i, k_prev, h):
    # 3a) button needs to be the *real-valued* centroid in [–1,1]
    b_real = model.quant.discrete_to_real(torch.tensor([[b_i]]))  # → shape (1,1)
    # 3b) wrap time & velocity
    t = torch.tensor([[t_i]], dtype=torch.float)
    v = torch.tensor([[v_i]], dtype=torch.float)

    # 3c) run decoder for one timestep
    with torch.no_grad():
        logits, h = model.dec(k_prev, t, b_real, v, h)
        # logits: (1,1,88)
        probs = torch.softmax(logits[0, 0], dim=-1)
        k_i = torch.multinomial(probs, num_samples=1)  # or .argmax()

    return k_i.reshape(1, 1), h, probs


# 4) Example usage:
#    Suppose the user hits button 3 at time=0.57s with velocity=90:
# k1, h, p = step(b_i=3, t_i=0.57, v_i=90, k_prev=k_prev, h=h)


# version 1
def letter_to_button_keyboard(letter):
    # Map letters on the keyboard to button indices, top row, middle row, bottom row
    top = "qwertyuiop"
    middle = "asdfghjkl"
    bottom = "zxcvbnm"
    if letter in top:
        return min(top.index(letter), 8), 40
    elif letter in middle:
        return min(middle.index(letter), 8), 80
    elif letter in bottom:
        return min(bottom.index(letter), 8), 120
    else:
        return 0, 0


def letter_to_button_26(letter):
    # Map letters to button indices for a 26-letter keyboard layout
    alphabet = "abcdefghijklmnopqrstuvwxyz"
    if letter in alphabet:
        return alphabet.index(letter)
    else:
        return 0  # Default case for unsupported characters


# read typing_intervals.csv
import csv

# 4) Example usage:
#    Suppose the user hits button 3 at time=0.57s with velocity=90:
# k1, h, p = step(b_i=3, t_i=0.57, v_i=90, k_prev=k_prev, h=h)

notes = []

filename = "typing_wpm_timestamps.csv"
filename_no_ext = filename.split(".")[0]

# To generate `eval.txt`, run this cell and pipe output to a file

with open(filename, "r") as f:
    reader = csv.reader(f)
    # skip header
    next(reader)
    for row in reader:
        print(row)
        letter, time, wpm = row[0], row[1], row[2]

        if not letter or not time:
            print("Skipping empty row")
            continue

        # check if wpm is numeric
        if "Inf" in wpm:
            print(f"Skipping non-numeric wpm: {wpm}")
            continue

        time = float(time)
        wpm = float(wpm)

        print(letter, time, wpm)
        # convert letter to button index and velocity
        letter = letter.lower()

        # velocity is used from mapping function, mapping is based on keyboard layout
        button = letter_to_button_26(letter)
        velocity = int(wpm * 2)
        # ensure velocity is in range 1-127
        velocity = max(1, min(127, velocity))
        k_prev, h, probs = step(b_i=button, t_i=time, v_i=velocity, k_prev=k_prev, h=h)
        notes.append((button, k_prev.item(), time, velocity))
print(notes)
# generate the midi file
import pretty_midi
import time

pm = pretty_midi.PrettyMIDI()
instr = pretty_midi.Instrument(program=0)

for i, (button, note, onset, vel) in enumerate(notes):
    # define a duration for each note
    if i + 1 < len(notes):
        end = notes[i + 1][1]
    else:
        end = onset + 0.5
    pm_note = pretty_midi.Note(velocity=vel, pitch=note, start=onset, end=end)
    print(pm_note)
    instr.notes.append(pm_note)

pm.instruments.append(instr)
filename = f"output_{time.time()}_{filename_no_ext}.mid"
pm.write(filename) 
['I', '3.496', '10.3']
I 3.496 10.3
[' ', '3.622', '13.3']
  3.622 13.3
['w', '3.697', '16.2']
w 3.697 16.2
['Backspace', '3.957', '18.2']
Backspace 3.957 18.2
['a', '4.023', '20.9']
a 4.023 20.9
['m', '4.121', '23.3']
m 4.121 23.3
[' ', '4.229', '25.5']
  4.229 25.5
['g', '4.298', '27.9']
g 4.298 27.9
['o', '4.399', '30.0']
o 4.399 30.0
['i', '4.538', '31.7']
i 4.538 31.7
['n', '4.591', '34.0']
n 4.591 34.0
['g', '4.665', '36.0']
g 4.665 36.0
[' ', '4.731', '38.0']
  4.731 38.0
['t', '4.798', '40.0']
t 4.798 40.0
['o', '4.837', '42.2']
o 4.837 42.2
[' ', '4.923', '43.9']
  4.923 43.9
['p', '5.040', '45.2']
p 5.04 45.2
['a', '5.102', '47.0']
a 5.102 47.0
['s', '5.246', '48.0']
s 5.246 48.0
['s', '5.367', '49.2']
s 5.367 49.2
[' ', '5.460', '50.5']
  5.46 50.5
['Shift', '5.741', '50.2']
Shift 5.741 50.2
['C', '5.798', '51.7']
C 5.798 51.7
['S', '5.948', '52.5']
S 5.948 52.5
['E', '6.136', '52.8']
E 6.136 52.8
[' ', '6.335', '53.0']
  6.335 53.0
['2', '6.525', '53.3']
2 6.525 53.3
['1', '6.571', '54.8']
1 6.571 54.8
['Backspace', '6.903', '53.9']
Backspace 6.903 53.9
['Backspace', '7.038', '54.6']
Backspace 7.038 54.6
['1', '7.082', '55.9']
1 7.082 55.9
['2', '7.159', '57.0']
2 7.159 57.0
['5', '7.328', '57.3']
5 7.328 57.3
[' ', '7.531', '57.4']
  7.531 57.4
['g', '7.778', '57.1']
g 7.778 57.1
['o', '7.887', '57.8']
o 7.887 57.8
['d', '8.012', '58.4']
d 8.012 58.4
['d', '8.140', '59.0']
d 8.14 59.0
['a', '8.586', '57.3']
a 8.586 57.3
['m', '8.694', '58.0']
m 8.694 58.0
['n', '8.836', '58.4']
n 8.836 58.4
[' ', '8.975', '58.8']
  8.975 58.8
['i', '9.085', '59.4']
i 9.085 59.4
['t', '9.196', '60.0']
t 9.196 60.0
['.', '9.390', '60.1']
. 9.39 60.1
[(8, 35, 3.496, 20), (0, 30, 3.622, 26), (22, 48, 3.697, 32), (0, 28, 3.957, 36), (0, 31, 4.023, 41), (12, 42, 4.121, 46), (0, 31, 4.229, 51), (6, 39, 4.298, 55), (14, 44, 4.399, 60), (8, 40, 4.538, 63), (13, 44, 4.591, 68), (6, 39, 4.665, 72), (0, 32, 4.731, 76), (19, 49, 4.798, 80), (14, 46, 4.837, 84), (0, 32, 4.923, 87), (15, 47, 5.04, 90), (0, 32, 5.102, 94), (18, 50, 5.246, 96), (18, 48, 5.367, 98), (0, 32, 5.46, 101), (0, 32, 5.741, 100), (2, 37, 5.798, 103), (18, 48, 5.948, 105), (4, 37, 6.136, 105), (0, 32, 6.335, 106), (0, 33, 6.525, 106), (0, 32, 6.571, 109), (0, 32, 6.903, 107), (0, 32, 7.038, 109), (0, 32, 7.082, 111), (0, 32, 7.159, 114), (0, 32, 7.328, 114), (0, 32, 7.531, 114), (6, 40, 7.778, 114), (14, 46, 7.887, 115), (3, 38, 8.012, 116), (3, 38, 8.14, 118), (0, 32, 8.586, 114), (12, 46, 8.694, 116), (13, 46, 8.836, 116), (0, 32, 8.975, 117), (8, 43, 9.085, 118), (19, 50, 9.196, 120), (0, 33, 9.39, 120)]
Note(start=3.496000, end=30.000000, pitch=35, velocity=20)
Note(start=3.622000, end=48.000000, pitch=30, velocity=26)
Note(start=3.697000, end=28.000000, pitch=48, velocity=32)
Note(start=3.957000, end=31.000000, pitch=28, velocity=36)
Note(start=4.023000, end=42.000000, pitch=31, velocity=41)
Note(start=4.121000, end=31.000000, pitch=42, velocity=46)
Note(start=4.229000, end=39.000000, pitch=31, velocity=51)
Note(start=4.298000, end=44.000000, pitch=39, velocity=55)
Note(start=4.399000, end=40.000000, pitch=44, velocity=60)
Note(start=4.538000, end=44.000000, pitch=40, velocity=63)
Note(start=4.591000, end=39.000000, pitch=44, velocity=68)
Note(start=4.665000, end=32.000000, pitch=39, velocity=72)
Note(start=4.731000, end=49.000000, pitch=32, velocity=76)
Note(start=4.798000, end=46.000000, pitch=49, velocity=80)
Note(start=4.837000, end=32.000000, pitch=46, velocity=84)
Note(start=4.923000, end=47.000000, pitch=32, velocity=87)
Note(start=5.040000, end=32.000000, pitch=47, velocity=90)
Note(start=5.102000, end=50.000000, pitch=32, velocity=94)
Note(start=5.246000, end=48.000000, pitch=50, velocity=96)
Note(start=5.367000, end=32.000000, pitch=48, velocity=98)
Note(start=5.460000, end=32.000000, pitch=32, velocity=101)
Note(start=5.741000, end=37.000000, pitch=32, velocity=100)
Note(start=5.798000, end=48.000000, pitch=37, velocity=103)
Note(start=5.948000, end=37.000000, pitch=48, velocity=105)
Note(start=6.136000, end=32.000000, pitch=37, velocity=105)
Note(start=6.335000, end=33.000000, pitch=32, velocity=106)
Note(start=6.525000, end=32.000000, pitch=33, velocity=106)
Note(start=6.571000, end=32.000000, pitch=32, velocity=109)
Note(start=6.903000, end=32.000000, pitch=32, velocity=107)
Note(start=7.038000, end=32.000000, pitch=32, velocity=109)
Note(start=7.082000, end=32.000000, pitch=32, velocity=111)
Note(start=7.159000, end=32.000000, pitch=32, velocity=114)
Note(start=7.328000, end=32.000000, pitch=32, velocity=114)
Note(start=7.531000, end=40.000000, pitch=32, velocity=114)
Note(start=7.778000, end=46.000000, pitch=40, velocity=114)
Note(start=7.887000, end=38.000000, pitch=46, velocity=115)
Note(start=8.012000, end=38.000000, pitch=38, velocity=116)
Note(start=8.140000, end=32.000000, pitch=38, velocity=118)
Note(start=8.586000, end=46.000000, pitch=32, velocity=114)
Note(start=8.694000, end=46.000000, pitch=46, velocity=116)
Note(start=8.836000, end=32.000000, pitch=46, velocity=116)
Note(start=8.975000, end=43.000000, pitch=32, velocity=117)
Note(start=9.085000, end=50.000000, pitch=43, velocity=118)
Note(start=9.196000, end=33.000000, pitch=50, velocity=120)
Note(start=9.390000, end=9.890000, pitch=33, velocity=120)

Eval.txt¶

The dropdown below shows that our

eval.txt ``` File: baseline_english_1_1748804204.7677963.mid (4, 0.6468253968253969) Key i: Pitches: 28

Key t: Pitches: 25

Key w: Pitches: 22

Key a: Pitches: 31

Key s: Pitches: 32

Key b: Pitches: 44

Key r: Pitches: 24

Key g: Pitches: 35

Key h: Pitches: 36

Key c: Pitches: 42

Key o: Pitches: 29

Key l: Pitches: 39

Key d: Pitches: 33

Key y: Pitches: 26

Key n: Pitches: 45

Key p: Pitches: 30

Key e: Pitches: 23

Key k: Pitches: 38

Key m: Pitches: 46

Key u: Pitches: 27

Key z: Pitches: 40

Key f: Pitches: 34

Key v: Pitches: 43

Key q: Pitches: 21

File: baseline_english_2_1748804270.5347235.mid (0, 0.6462882096069869) Key t: Pitches: 25

Key h: Pitches: 36

Key e: Pitches: 23

Key l: Pitches: 39

Key a: Pitches: 31

Key n: Pitches: 45

Key w: Pitches: 22

Key i: Pitches: 28

Key d: Pitches: 33

Key m: Pitches: 46

Key u: Pitches: 27

Key c: Pitches: 42

Key o: Pitches: 29

Key f: Pitches: 34

Key p: Pitches: 30

Key s: Pitches: 32

Key r: Pitches: 24

Key k: Pitches: 38

Key g: Pitches: 35

Key b: Pitches: 44

Key q: Pitches: 21

Key y: Pitches: 26

Key v: Pitches: 43

File: baseline_english_3_1748827500.4073334.mid (0, 0.6820276497695853) Key i: Pitches: 28

Key f: Pitches: 34

Key m: Pitches: 46

Key u: Pitches: 27

Key s: Pitches: 32

Key c: Pitches: 42

Key b: Pitches: 44

Key e: Pitches: 23

Key t: Pitches: 25

Key h: Pitches: 36

Key o: Pitches: 29

Key d: Pitches: 33

Key l: Pitches: 39

Key v: Pitches: 43

Key p: Pitches: 30

Key a: Pitches: 31

Key y: Pitches: 26

Key n: Pitches: 45

Key g: Pitches: 35

Key x: Pitches: 41

Key r: Pitches: 24

Key k: Pitches: 38

Key w: Pitches: 22

File: baseline_english_3_1748827756.1898053.mid (0, 0.6820276497695853) Key i: Pitches: 28

Key f: Pitches: 34

Key m: Pitches: 46

Key u: Pitches: 27

Key s: Pitches: 32

Key c: Pitches: 42

Key b: Pitches: 44

Key e: Pitches: 23

Key t: Pitches: 25

Key h: Pitches: 36

Key o: Pitches: 29

Key d: Pitches: 33

Key l: Pitches: 39

Key v: Pitches: 43

Key p: Pitches: 30

Key a: Pitches: 31

Key y: Pitches: 26

Key n: Pitches: 45

Key g: Pitches: 35

Key x: Pitches: 41

Key r: Pitches: 24

Key k: Pitches: 38

Key w: Pitches: 22

File: baseline_performance_1_1748805631.1922073.mid (4, 0.6428571428571429) Key a: Pitches: 31

Key d: Pitches: 33

Key g: Pitches: 35

Key s: Pitches: 32

Key p: Pitches: 30

Key o: Pitches: 29

Key f: Pitches: 34

File: baseline_performance_2_1748805691.4368095.mid (5, 0.7096774193548387) Key z: Pitches: 40

Key v: Pitches: 43

Key b: Pitches: 44

Key n: Pitches: 45

Key k: Pitches: 38

Key m: Pitches: 46

Key f: Pitches: 34

Key j: Pitches: 37

Key l: Pitches: 39

Key d: Pitches: 33

File: baseline_performance_3_1748827470.3868184.mid (0, 0.7777777777777778) Key q: Pitches: 21

Key z: Pitches: 40

Key x: Pitches: 41

Key v: Pitches: 43

Key c: Pitches: 42

Key n: Pitches: 45

File: output_english_1_1748804467.0220675.mid (1, 0.6865079365079365) Key i: Pitches: 42, 40, 41, 44, 45, 46, 47, 48

Key t: Pitches: 49, 48, 50, 51, 52, 53, 54, 57, 56

Key w: Pitches: 48, 51, 50, 56, 58, 60

Key a: Pitches: 30, 31, 32, 33, 37

Key s: Pitches: 47, 48, 49, 50, 51, 53, 55, 54, 58, 57

Key b: Pitches: 32, 33

Key r: Pitches: 46, 48, 47, 50, 49, 51, 52, 56

Key g: Pitches: 39, 40, 42, 43, 44, 45

Key h: Pitches: 41, 42, 44, 45, 47

Key c: Pitches: 36, 37, 38, 39

Key o: Pitches: 46, 47, 48, 50, 49, 51, 53, 55

Key l: Pitches: 43, 44, 45, 46, 47, 49, 51

Key d: Pitches: 37, 38, 39, 40

Key y: Pitches: 53, 57, 58, 61

Key n: Pitches: 45, 46, 47, 50, 49, 52, 53, 54

Key p: Pitches: 46, 48, 49, 53

Key e: Pitches: 39, 38, 40, 41, 44

Key k: Pitches: 44, 43, 45, 48

Key m: Pitches: 44, 47, 51

Key u: Pitches: 51, 53, 54, 56, 58

Key z: Pitches: 54, 56

Key f: Pitches: 40, 39, 41, 44

Key v: Pitches: 54, 56, 57

Key q: Pitches: 51, 54

File: output_english_1_1748804472.0935662.mid (1, 0.6904761904761905) Key i: Pitches: 43, 40, 41, 42, 44, 45, 46, 47, 48

Key t: Pitches: 48, 50, 49, 51, 53, 54, 56, 57, 58

Key w: Pitches: 48, 51, 57, 58, 61

Key a: Pitches: 31, 30, 32, 33, 36, 37

Key s: Pitches: 47, 48, 49, 50, 53, 55, 52, 54, 56, 57

Key b: Pitches: 33, 32

Key r: Pitches: 46, 48, 47, 49, 51, 52, 53, 55, 56, 57

Key g: Pitches: 39, 40, 42, 44, 45, 46

Key h: Pitches: 40, 41, 44, 45, 46, 48

Key c: Pitches: 35, 36, 37, 38, 40

Key o: Pitches: 46, 47, 49, 51, 48, 50, 52, 53, 54, 55

Key l: Pitches: 43, 44, 46, 47, 49, 51

Key d: Pitches: 37, 38, 39, 40

Key y: Pitches: 54, 57, 58, 61

Key n: Pitches: 45, 46, 47, 50, 49, 51, 53, 54

Key p: Pitches: 47, 48, 49, 54

Key e: Pitches: 38, 39, 40, 41, 44

Key k: Pitches: 44, 43, 45, 47

Key m: Pitches: 45, 47, 51

Key u: Pitches: 51, 54, 56, 58

Key z: Pitches: 54, 56

Key f: Pitches: 40, 39, 41, 44

Key v: Pitches: 55, 57, 58

Key q: Pitches: 51, 54

File: output_english_1_1748804484.7478225.mid (3, 0.6865079365079365) Key i: Pitches: 43, 40, 41, 42, 44, 45, 46, 47, 48

Key t: Pitches: 48, 50, 49, 51, 52, 53, 54, 56, 57, 58

Key w: Pitches: 48, 51, 56, 58

Key a: Pitches: 31, 34, 32, 33, 30, 36

Key s: Pitches: 47, 48, 49, 50, 51, 53, 55, 54, 56, 57

Key b: Pitches: 32, 33

Key r: Pitches: 46, 47, 48, 50, 52, 53, 55, 56

Key g: Pitches: 39, 40, 42, 44, 45

Key h: Pitches: 40, 41, 42, 44, 45, 46, 47

Key c: Pitches: 34, 36, 37, 32, 38, 39

Key o: Pitches: 46, 47, 49, 50, 51, 52, 53, 55, 54

Key l: Pitches: 44, 46, 45, 47, 49, 51

Key d: Pitches: 37, 38, 39, 40

Key y: Pitches: 53, 57, 58, 59

Key n: Pitches: 45, 46, 47, 50, 49, 51, 53, 54

Key p: Pitches: 46, 48, 49, 54

Key e: Pitches: 39, 38, 41, 40, 42, 44

Key k: Pitches: 44, 43, 45, 47

Key m: Pitches: 44, 48, 53

Key u: Pitches: 51, 54, 56, 57, 58

Key z: Pitches: 55, 56

Key f: Pitches: 40, 39, 41, 44

Key v: Pitches: 55, 57, 58

Key q: Pitches: 50, 54

File: output_english_2_1748804434.1069958.mid (1, 0.6812227074235808) Key t: Pitches: 54, 49, 50, 51, 53, 56, 57, 58

Key h: Pitches: 39, 41, 40, 42, 43, 44, 45, 46

Key e: Pitches: 37, 38, 39, 40, 41, 43, 42

Key l: Pitches: 43, 44, 45, 46, 47

Key a: Pitches: 31, 32, 33, 30, 29, 37

Key n: Pitches: 45, 44, 46, 47, 48, 50

Key w: Pitches: 51, 52, 53, 56, 57

Key i: Pitches: 41, 43, 44, 47

Key d: Pitches: 37, 38, 39, 40, 41

Key m: Pitches: 45

Key u: Pitches: 50, 51, 53, 56, 57

Key c: Pitches: 37, 36, 38

Key o: Pitches: 46, 48, 49, 50, 51, 53, 54

Key f: Pitches: 39, 44

Key p: Pitches: 46, 47

Key s: Pitches: 48, 49, 50, 54, 56

Key r: Pitches: 49, 47, 51, 54, 55

Key k: Pitches: 44, 45

Key g: Pitches: 41, 44, 45

Key b: Pitches: 34, 33, 36, 32, 37, 39

Key q: Pitches: 49

Key y: Pitches: 56, 57, 70, 61, 62

Key v: Pitches: 57

File: output_english_2_1748804446.052948.mid (3, 0.6899563318777293) Key t: Pitches: 54, 49, 50, 53, 55, 56, 57

Key h: Pitches: 39, 41, 40, 42, 43, 44, 45, 46

Key e: Pitches: 38, 37, 39, 40, 41, 42, 44

Key l: Pitches: 43, 44, 45, 47

Key a: Pitches: 31, 32, 33, 30, 29, 38

Key n: Pitches: 45, 44, 46, 47, 48, 50

Key w: Pitches: 51, 52, 53, 56, 55, 57

Key i: Pitches: 41, 43, 44, 47, 48

Key d: Pitches: 37, 38, 39, 40

Key m: Pitches: 44, 45

Key u: Pitches: 50, 51, 53, 56, 58

Key c: Pitches: 37, 36, 38

Key o: Pitches: 46, 48, 49, 50, 51, 52, 53, 54, 55

Key f: Pitches: 39, 44

Key p: Pitches: 46, 47

Key s: Pitches: 48, 49, 50, 54, 55, 56

Key r: Pitches: 48, 49, 52, 51, 55, 54

Key k: Pitches: 44, 46

Key g: Pitches: 41, 44

Key b: Pitches: 33, 36, 35, 37, 38, 39

Key q: Pitches: 50

Key y: Pitches: 56, 58, 61

Key v: Pitches: 57

File: output_english_3_1748804822.122112.mid (3, 0.728110599078341) Key i: Pitches: 39, 41, 43, 44, 46, 47

Key f: Pitches: 38, 39, 41, 44

Key m: Pitches: 44, 45, 48, 47, 55

Key u: Pitches: 48, 49, 56, 57, 58

Key s: Pitches: 47, 48, 50, 56, 55, 57, 58

Key c: Pitches: 36, 34

Key b: Pitches: 34, 32, 38

Key e: Pitches: 38, 39, 40, 41, 42, 43, 44

Key t: Pitches: 49, 48, 50, 51, 53, 54, 55, 56, 58

Key h: Pitches: 40, 41, 43, 44, 45, 49

Key o: Pitches: 46, 45, 48, 50, 51, 53, 52, 54, 55, 56

Key d: Pitches: 37, 38, 39, 40, 41, 42

Key l: Pitches: 44, 46, 47, 49, 51

Key v: Pitches: 51, 58

Key p: Pitches: 46, 54

Key a: Pitches: 31, 32, 33, 30, 36, 37

Key y: Pitches: 53, 57, 58

Key n: Pitches: 45, 46, 47, 49, 51, 53, 54, 55

Key g: Pitches: 39, 40, 41, 42, 45, 46

Key x: Pitches: 53

Key r: Pitches: 47, 49, 54, 55, 56, 57, 58

Key k: Pitches: 44, 46, 48

Key w: Pitches: 57

File: output_english_3_1748804834.8243206.mid (3, 0.7235023041474654) Key i: Pitches: 39, 41, 42, 43, 44, 46, 47

Key f: Pitches: 38, 39, 40, 44

Key m: Pitches: 44, 45, 48, 47, 55

Key u: Pitches: 48, 49, 56, 57, 58

Key s: Pitches: 48, 47, 50, 56, 57

Key c: Pitches: 36, 34

Key b: Pitches: 34, 32, 38

Key e: Pitches: 38, 39, 40, 41, 42, 43, 44, 45

Key t: Pitches: 49, 48, 50, 51, 53, 54, 56, 57, 58

Key h: Pitches: 41, 43, 44, 45, 47

Key o: Pitches: 46, 45, 48, 50, 53, 54, 55, 56

Key d: Pitches: 37, 38, 39, 40, 42, 43

Key l: Pitches: 44, 47, 46, 50, 51

Key v: Pitches: 51, 58

Key p: Pitches: 46, 55

Key a: Pitches: 32, 33, 31, 30, 36

Key y: Pitches: 53, 57, 58

Key n: Pitches: 45, 46, 47, 48, 50, 51, 53, 54, 56, 55

Key g: Pitches: 39, 41, 42, 45

Key x: Pitches: 53

Key r: Pitches: 47, 49, 53, 55, 56, 57, 58

Key k: Pitches: 44, 46, 48

Key w: Pitches: 58

File: output_performance_1_1748805169.7088091.mid (1, 0.7380952380952381) Key a: Pitches: 18, 31, 30, 32

Key d: Pitches: 35, 34, 37, 36

Key g: Pitches: 38, 39

Key s: Pitches: 46, 48, 47, 49

Key p: Pitches: 45, 46, 47

Key o: Pitches: 44

Key f: Pitches: 39

File: output_performance_1_1748805218.55786.mid (6, 0.6904761904761905) Key a: Pitches: 22, 30, 31

Key d: Pitches: 35, 37, 36

Key g: Pitches: 38

Key s: Pitches: 46, 45, 48, 47, 49

Key p: Pitches: 45, 44, 46

Key o: Pitches: 44

Key f: Pitches: 38, 39

File: output_performance_2_1748805761.0151947.mid (3, 0.7419354838709677) Key z: Pitches: 70, 50, 53

Key v: Pitches: 51, 48

Key b: Pitches: 32, 33

Key n: Pitches: 44, 43, 45

Key k: Pitches: 41

Key m: Pitches: 43, 44

Key f: Pitches: 38

Key j: Pitches: 41

Key l: Pitches: 43, 45

Key d: Pitches: 37

File: output_performance_2_1748805798.0441349.mid (3, 0.7741935483870968) Key z: Pitches: 70, 51, 53

Key v: Pitches: 51, 48

Key b: Pitches: 32, 33

Key n: Pitches: 44, 43, 45

Key k: Pitches: 41

Key m: Pitches: 43, 44

Key f: Pitches: 38

Key j: Pitches: 41

Key l: Pitches: 43, 44

Key d: Pitches: 37

File: output_performance_2_1748805941.4645994.mid (3, 0.8064516129032258) Key z: Pitches: 80, 50, 53

Key v: Pitches: 51, 48

Key b: Pitches: 32

Key n: Pitches: 44, 43, 45

Key k: Pitches: 41

Key m: Pitches: 43, 44, 45

Key f: Pitches: 38

Key j: Pitches: 41

Key l: Pitches: 43, 44

Key d: Pitches: 36

File: output_performance_3_1748805951.9792469.mid (4, 0.7777777777777778) Key q: Pitches: 53

Key z: Pitches: 54

Key x: Pitches: 49, 48

Key v: Pitches: 47, 48

Key c: Pitches: 32, 33

Key n: Pitches: 43, 44

File: output_performance_3_1748805957.8815727.mid (4, 0.7777777777777778) Key q: Pitches: 54

Key z: Pitches: 55

Key x: Pitches: 49, 48

Key v: Pitches: 47

Key c: Pitches: 33, 34, 32

Key n: Pitches: 43, 44

Below, you can see that clicking a bin results in different notes with multiple clicks.

In [59]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# notes now should be list of tuples (bin, note, time, velocity)
df = pd.DataFrame(notes, columns=['bin','note','time','velocity'])

# 1) Bar chart: number of events per bin
plt.figure(figsize=(10,4))
df['bin'].value_counts().sort_index().plot(kind='bar', color='C0')
plt.xlabel('Bin index')
plt.ylabel('Count')
plt.title('Event count per bin')
plt.show()

# 2) Heatmap: bin × note frequency
pivot = df.groupby(['bin','note']).size().unstack(fill_value=0)
plt.figure(figsize=(12,8))
sns.heatmap(pivot, cmap='viridis', cbar_kws={'label':'Count'})
plt.xlabel('MIDI pitch')
plt.ylabel('Bin index')
plt.title('Bin→Note mapping frequencies')
plt.show()
# --- end diagnostics ---
No description has been provided for this image
No description has been provided for this image
In [ ]:
import torch
import json

# generate the midi file
import pretty_midi
import time as time_lib
import csv


cfg = json.load(open("cfg.json"))
model = PianoGenieAutoencoder(cfg)
model.load_state_dict(torch.load("model.pt", map_location="cpu"))
model.eval()

# 2) Prepare decoder state for a single stream
#    We’ll run batch_size=1 since we’re in interactive mode
h = model.dec.init_hidden(batch_size=1)
# k_prev holds the last output key index; start with SOS
k_prev = torch.full((1, 1), SOS, dtype=torch.long)

# 3) Now, each time the user presses a button, you have:
#    b_i: integer 0…7    (which button)
#    t_i: float          (absolute onset time in seconds)
#    v_i: int 1…127      (velocity)
#
# You’ll convert these into tensors, call the decoder, sample/argmax,
# and then feed that key back in as the next k_prev.


def step(b_i, t_i, v_i, k_prev, h):
    # 3a) button needs to be the *real-valued* centroid in [–1,1]
    b_real = model.quant.discrete_to_real(torch.tensor([[b_i]]))  # → shape (1,1)
    # 3b) wrap time & velocity
    t = torch.tensor([[t_i]], dtype=torch.float)
    v = torch.tensor([[v_i]], dtype=torch.float)

    # 3c) run decoder for one timestep
    with torch.no_grad():
        logits, h = model.dec(k_prev, t, b_real, v, h)
        # logits: (1,1,88)
        probs = torch.softmax(logits[0, 0], dim=-1)
        k_i = torch.multinomial(probs, num_samples=1)  # or .argmax()

    return k_i.reshape(1, 1), h, probs


# 4) Example usage:
#    Suppose the user hits button 3 at time=0.57s with velocity=90:
# k1, h, p = step(b_i=3, t_i=0.57, v_i=90, k_prev=k_prev, h=h)


# version 1
def letter_to_button_keyboard(letter):
    # Map letters on the keyboard to button indices, top row, middle row, bottom row
    top = "qwertyuiop"
    middle = "asdfghjkl"
    bottom = "zxcvbnm"
    if letter in top:
        return min(top.index(letter), 8), 40
    elif letter in middle:
        return min(middle.index(letter), 8), 80
    elif letter in bottom:
        return min(bottom.index(letter), 8), 120
    else:
        return 0, 0


# V2
def letter_to_button_26(letter):
    # Map letters to button indices for a 26-letter keyboard layout
    alphabet = "abcdefghijklmnopqrstuvwxyz"
    if letter in alphabet:
        return alphabet.index(letter)
    else:
        return 0  # Default case for unsupported characters


def letter_to_button_26_rowwise(letter):
    # Map letters to button indices for a 26-letter keyboard layout, row-wise
    alphabet = "qwertyuiopasdfghjklzxcvbnm"
    if letter in alphabet:
        return alphabet.index(letter)
    else:
        return 0  # Default case for unsupported characters


def extract_pitches_from_midi(midi_path):
    """
    Load a MIDI file and return a list of all note pitches (as MIDI numbers).
    """
    pm = pretty_midi.PrettyMIDI(midi_path)
    pitches = []

    # Loop over all instruments (tracks) in the file
    for instrument in pm.instruments:
        # Skip drum tracks if you only care about pitched notes:
        if instrument.is_drum:
            continue
        # Loop over each Note object in this instrument
        for note in instrument.notes:
            pitches.append(note.pitch)
    return pitches


# metric one
def infer_key_from_pitches(pitches):
    """
    Given a list of MIDI pitches (0–127), figure out which major key
    (0=C major, 1=C♯ major, …, 11=B major) has the most pitches landing
    on its diatonic pitch-classes.

    Returns (best_key_root, is_major=True).
    """
    # Define diatonic sets for all 12 major keys:
    # pitch-class → integer 0…11 where C=0, C♯=1, …, B=11
    major_scales = {
        0: {0, 2, 4, 5, 7, 9, 11},  # C major
        1: {1, 3, 5, 6, 8, 10, 0},  # C♯ major
        2: {2, 4, 6, 7, 9, 11, 1},  # D major
        3: {3, 5, 7, 8, 10, 0, 2},  # E♭/D♯ major
        4: {4, 6, 8, 9, 11, 1, 3},  # E major
        5: {5, 7, 9, 10, 0, 2, 4},  # F major
        6: {6, 8, 10, 11, 1, 3, 5},  # F♯ major
        7: {7, 9, 11, 0, 2, 4, 6},  # G major
        8: {8, 10, 0, 1, 3, 5, 7},  # A♭ major
        9: {9, 11, 1, 2, 4, 6, 8},  # A major
        10: {10, 0, 2, 3, 5, 7, 9},  # B♭ major
        11: {11, 1, 3, 4, 6, 8, 10},  # B major
    }

    # Build a histogram of pitch-classes
    pc_counts = [0] * 12
    for p in pitches:
        pc = p % 12
        pc_counts[pc] += 1

    # For each major key, count how many pitches are in its diatonic set
    best_key, best_count = None, -1
    for key_root, scale_set in major_scales.items():
        count_in_scale = sum(pc_counts[pc] for pc in scale_set)
        if count_in_scale > best_count:
            best_key, best_count = key_root, count_in_scale

    return best_key  # integer 0…11 (C major=0, C♯ major=1, etc.)


def compute_in_scale_ratio(notes, key_root=None):
    """
    notes: list of (midi_pitch, onset_time, velocity)
    key_root: if you already know the key (0…11), pass it in;
              otherwise, set key_root=None to auto-infer.
    Returns: (key_root, in_scale_ratio)
    """
    # 1) Extract just the pitches
    pitches = [p for p in notes]
    if key_root is None:
        key_root = infer_key_from_pitches(pitches)

    # 2) Get diatonic set for that key
    #    (using the same major_scales dict from above)
    major_scales = {
        0: {0, 2, 4, 5, 7, 9, 11},
        1: {1, 3, 5, 6, 8, 10, 0},
        2: {2, 4, 6, 7, 9, 11, 1},
        3: {3, 5, 7, 8, 10, 0, 2},
        4: {4, 6, 8, 9, 11, 1, 3},
        5: {5, 7, 9, 10, 0, 2, 4},
        6: {6, 8, 10, 11, 1, 3, 5},
        7: {7, 9, 11, 0, 2, 4, 6},
        8: {8, 10, 0, 1, 3, 5, 7},
        9: {9, 11, 1, 2, 4, 6, 8},
        10: {10, 0, 2, 3, 5, 7, 9},
        11: {11, 1, 3, 4, 6, 8, 10},
    }
    scale_set = major_scales[key_root]

    # 3) Count how many pitches lie in that set
    in_scale = sum(1 for p in pitches if (p % 12) in scale_set)
    total = len(pitches)
    ratio = in_scale / total if total > 0 else 0.0

    return key_root, ratio


def eval_map_key_to_pitches(pitches, filename):

    filename_no_ext = filename.split(".")[0]

    letters = []

    with open(filename, "r") as f:
        # print(f"Reading {filename} for letter-to-pitch mapping...")
        reader = csv.reader(f)
        # skip header
        next(reader)
        for row in reader:
            # print(row)
            letter, time, wpm = row[0], row[1], row[2]

            if not letter or not time or len(letter) != 1 or not letter.isalpha():
                # print("Skipping empty row")
                continue

            if wpm == "Infinity" or wpm == "inf":
                wpm = 80

            # check if wpm is numeric
            # if not wpm.isnumeric():
            #     print(f"Skipping non-numeric wpm: {wpm}")
            #     continue

            time = float(time)
            wpm = float(wpm)
            letter = letter.lower()

            letters.append(letter)

    letter_to_pitch_counts = {}

    if len(letters) != len(pitches):
        raise ValueError(
            f"Mismatch: {len(letters)} letters from CSV but {len(pitches)} notes provided."
        )

    for letter, pitch in zip(letters, pitches):
        # Ensure there’s a sub-dictionary for this letter
        if letter not in letter_to_pitch_counts:
            letter_to_pitch_counts[letter] = {}

        # Increment the count for this pitch under that letter
        subdict = letter_to_pitch_counts[letter]
        subdict[pitch] = subdict.get(pitch, 0) + 1

    return letter_to_pitch_counts


def baseline(filename):
    notes = []

    filename_no_ext = filename.split(".")[0]

    with open(filename, "r") as f:
        reader = csv.reader(f)
        # skip header
        next(reader)
        for row in reader:
            print(row)
            letter, time, wpm = row[0], row[1], row[2]

            if not letter or not time or len(letter) != 1 or not letter.isalpha():
                # print("Skipping empty row")
                continue

            if wpm == "Infinity" or wpm == "inf":
                wpm = 80

            # check if wpm is numeric
            # if not wpm.isnumeric():
            #     print(f"Skipping non-numeric wpm: {wpm}")
            #     continue

            time = float(time)
            wpm = float(wpm)

            print(letter, time, wpm)
            # convert letter to button index and velocity
            letter = letter.lower()

            # velocity is used from mapping function, mapping is based on keyboard layout
            button = letter_to_button_26_rowwise(letter)
            velocity = int(wpm) * 2
            # ensure velocity is in range 1-127
            velocity = max(1, min(127, velocity))

            # do a naive mapping of button to key index, without using step
            key_to_add = (
                button + 21
            )  # C4 is MIDI note 60, so we offset by 21 to get the right range
            notes.append((key_to_add, time, velocity))
    return notes


# read typing_intervals.csv


# 4) Example usage:
#    Suppose the user hits button 3 at time=0.57s with velocity=90:
# k1, h, p = step(b_i=3, t_i=0.57, v_i=90, k_prev=k_prev, h=h)


# generate midi from timestamps
notes = []

filename = "english_3.csv"
filename_no_ext = filename.split(".")[0]

with open(filename, "r") as f:
    reader = csv.reader(f)
    # skip header
    next(reader)
    for row in reader:
        print(row)
        letter, time, wpm = row[0], row[1], row[2]

        if not letter or not time or len(letter) != 1 or not letter.isalpha():
            # print("Skipping empty row")
            continue

        if wpm == "Infinity" or wpm == "inf":
            wpm = 80

        # check if wpm is numeric
        # if not wpm.isnumeric():
        #     print(f"Skipping non-numeric wpm: {wpm}")
        #     continue

        time = float(time)
        wpm = float(wpm)

        print(letter, time, wpm)
        # convert letter to button index and velocity
        letter = letter.lower()

        # velocity is used from mapping function, mapping is based on keyboard layout
        button = letter_to_button_26(letter)
        velocity = int(wpm) * 2
        # ensure velocity is in range 1-127
        velocity = max(1, min(127, velocity))
        k_prev, h, probs = step(b_i=button, t_i=time, v_i=velocity, k_prev=k_prev, h=h)
        notes.append((k_prev.item(), time, velocity))
print(notes)


# use baseline

use_baseline = True

if use_baseline:
    notes = baseline(filename)

pm = pretty_midi.PrettyMIDI()
instr = pretty_midi.Instrument(program=0)

for i, (note, onset, vel) in enumerate(notes):
    # define a duration for each note
    if i + 1 < len(notes):
        end = notes[i + 1][1]
    else:
        end = onset + 0.5
    pm_note = pretty_midi.Note(velocity=vel, pitch=note, start=onset, end=end)
    print(pm_note)
    instr.notes.append(pm_note)

pm.instruments.append(instr)
filename = f"output_{filename_no_ext}_{time_lib.time()}.mid"

if use_baseline:
    filename = f"baseline_{filename_no_ext}_{time_lib.time()}.mid"
pm.write(filename)


# run evals
import glob

# Look for files ending in .mid or .midi in the current directory
midi_patterns = ["*.mid", "*.midi"]
all_files = []
for pattern in midi_patterns:
    all_files.extend(glob.glob(pattern))

if not all_files:
    print("No MIDI files found in the current directory.")


# For each MIDI file, extract pitches and print a summary
for midi_file in sorted(all_files):
    pitches = []
    try:
        pitches = extract_pitches_from_midi(midi_file)
    except Exception as e:
        print(f"Error reading {midi_file}: {e}")
        continue
    print(f"File: {midi_file}")
    print(compute_in_scale_ratio(pitches))

    input_name = ""
    if "performance_1" in midi_file:
        input_name = "performance_1.csv"
    elif "performance_2" in midi_file:
        input_name = "performance_2.csv"
    elif "performance_3" in midi_file:
        input_name = "performance_3.csv"
    elif "english_1" in midi_file:
        input_name = "english_1.csv"
    elif "english_2" in midi_file:
        input_name = "english_2.csv"
    elif "english_3" in midi_file:
        input_name = "english_3.csv"

    key_to_pitches = eval_map_key_to_pitches(pitches, input_name)

    for key, pitches in key_to_pitches.items():
        print(f"Key {key}:")
        print("  Pitches:", ", ".join(str(p) for p in pitches))
        print()